diff --git a/Cargo.toml b/Cargo.toml index 49b8322..b0f4b25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "rtc-datachannel", "rtc-dtls", "rtc", + "rtc-turn", "rtc-rtcp", "rtc-rtp", "rtc-sctp", @@ -14,7 +15,6 @@ members = [ "reserved/rtc-interceptor", "reserved/rtc-mdns", "reserved/rtc-media", - "reserved/rtc-turn", ] resolver = "2" diff --git a/reserved/ice/.gitignore b/reserved/ice/.gitignore new file mode 100644 index 0000000..81561ed --- /dev/null +++ b/reserved/ice/.gitignore @@ -0,0 +1,11 @@ +# Generated by Cargo +# will have compiled files and executables +/target/ +/.idea/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk diff --git a/reserved/ice/CHANGELOG.md b/reserved/ice/CHANGELOG.md new file mode 100644 index 0000000..07f16c5 --- /dev/null +++ b/reserved/ice/CHANGELOG.md @@ -0,0 +1,49 @@ +# webrtc-ice changelog + +## Unreleased + +### Breaking changes + +* remove the unused `MulticastDnsMode::Unspecified` variant [#404](https://github.com/webrtc-rs/webrtc/pull/404). + +## v0.9.0 + +* Increased minimum supported Rust version to `1.60.0`. + +### Breaking changes + +* Make functions non-async [#338](https://github.com/webrtc-rs/webrtc/pull/338): + - `Agent`: + - `get_bytes_received`; + - `get_bytes_sent`; + - `on_connection_state_change`; + - `on_selected_candidate_pair_change`; + - `on_candidate`; + - `add_remote_candidate`; + - `gather_candidates`. + - `unmarshal_candidate`; + - `CandidateHostConfig::new_candidate_host`; + - `CandidatePeerReflexiveConfig::new_candidate_peer_reflexive`; + - `CandidateRelayConfig::new_candidate_relay`; + - `CandidateServerReflexiveConfig::new_candidate_server_reflexive`; + - `Candidate`: + - `addr`; + - `set_ip`. + +## v0.8.2 + +* Add IP filter to ICE `AgentConfig` [#306](https://github.com/webrtc-rs/webrtc/pull/306) and [#318](https://github.com/webrtc-rs/webrtc/pull/318). +* Add `rust-version` at 1.57.0 to `Cargo.toml`. This was already the minimum version so does not constitute a change. + +## v0.8.1 + +This release was released in error and contains no changes from 0.8.0. + +## v0.8.0 + +* Increased the minimum version of the `log` dependency to `0.4.16`. [#250 Fix log at ^0.4.16 to make tests compile](https://github.com/webrtc-rs/webrtc/pull/250) by [@k0nserv](https://github.com/k0nserv). +* Increased serde's minimum version to 1.0.102 [#243 Fixes for cargo minimal-versions](https://github.com/webrtc-rs/webrtc/pull/243) contributed by [algesten](https://github.com/algesten). + +## Prior to 0.8.0 + +Before 0.8.0 there was no changelog; previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/ice/releases).
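To make the v0.9.0 "make functions non-async" list above concrete, here is a small, hypothetical upgrade sketch. It mirrors how `examples/ping_pong.rs` later in this diff calls the affected APIs, now without `.await`; the trait-object type parameters (e.g. `Arc<dyn Candidate + Send + Sync>`) are filled in from the upstream `webrtc-ice` crate and are an assumption, not text taken from this changeset.

```rust
use std::sync::Arc;

use ice::agent::Agent;
use ice::candidate::candidate_base::unmarshal_candidate;
use ice::candidate::Candidate;
use ice::state::ConnectionState;
use ice::Error;

/// Hypothetical helper: every call below required `.await` before v0.9.0.
fn wire_up(agent: &Agent, serialized_remote_candidate: &str) -> Result<(), Error> {
    // Callback registration is now a plain function call.
    agent.on_connection_state_change(Box::new(|state: ConnectionState| {
        println!("ICE connection state: {state}");
        Box::pin(async move {})
    }));

    // Parsing and adding a remote candidate are synchronous as well.
    let c = unmarshal_candidate(serialized_remote_candidate)?;
    let c: Arc<dyn Candidate + Send + Sync> = Arc::new(c);
    agent.add_remote_candidate(&c)?;

    // Gathering is kicked off without awaiting (an `on_candidate` handler should
    // already be registered); results arrive through that callback.
    agent.gather_candidates()?;
    Ok(())
}
```

In `examples/ping_pong.rs`, connection establishment (`dial`/`accept`) and the credential getter are still awaited; only the calls listed in the changelog entry became synchronous.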
diff --git a/reserved/ice/Cargo.toml b/reserved/ice/Cargo.toml new file mode 100644 index 0000000..db10d84 --- /dev/null +++ b/reserved/ice/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "ice" +version = "0.0.0" +authors = ["Rain Liu "] +edition = "2021" +description = "ICE in Rust" +license = "MIT/Apache-2.0" + +[dependencies] +util = { version = "0.7.0", package = "webrtc-util", default-features = false, features = ["conn", "vnet", "sync"] } +turn = { version = "0.6.1"} +stun = { version = "0.4.3"} +mdns = { version = "0.5.0", package = "webrtc-mdns" } + +arc-swap = "1.5" +async-trait = "0.1.56" +crc = "3.0" +log = "0.4.16" +rand = "0.8.5" +serde = { version = "1.0.102", features = ["derive"] } +serde_json = "1.0" +thiserror = "1.0" +tokio = { version = "1.19", features = ["full"] } +url = "2.2" +uuid = { version = "1.1", features = ["v4"] } +waitgroup = "0.1.2" + +[dev-dependencies] +tokio-test = "0.4.0" # must match the min version of the `tokio` crate above +regex = "1" +env_logger = "0.9.0" +chrono = "0.4.23" +ipnet = "2.5.0" +clap = "3.2.6" +lazy_static = "1.4.0" +hyper = { version = "0.14.19", features = ["full"] } +sha1 = "0.10.5" + +[[example]] +name = "ping_pong" +path = "examples/ping_pong.rs" +bench = false diff --git a/reserved/ice/LICENSE-APACHE b/reserved/ice/LICENSE-APACHE new file mode 100644 index 0000000..16fe87b --- /dev/null +++ b/reserved/ice/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/reserved/ice/LICENSE-MIT b/reserved/ice/LICENSE-MIT new file mode 100644 index 0000000..e11d93b --- /dev/null +++ b/reserved/ice/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 WebRTC.rs + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/reserved/ice/README.md b/reserved/ice/README.md new file mode 100644 index 0000000..aa28286 --- /dev/null +++ b/reserved/ice/README.md @@ -0,0 +1,30 @@ +

+ WebRTC.rs
+
+ License: MIT/Apache 2.0 | Discord
+
+ A pure Rust implementation of ICE. Rewrite Pion ICE in Rust
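As a quick orientation before the crate sources that follow, the sketch below strings together the `Agent` calls used by `examples/ping_pong.rs` later in this diff: build an `AgentConfig`, register a candidate callback, gather candidates, and dial as the controlling side. It is illustrative only; the exchange of credentials and candidates with the remote peer (done over HTTP in the example) is left out, and the trait-object type parameters are reconstructed from the upstream `webrtc-ice` API rather than copied from this diff.

```rust
use std::sync::Arc;

use ice::agent::agent_config::AgentConfig;
use ice::agent::Agent;
use ice::candidate::Candidate;
use ice::network_type::NetworkType;
use ice::Error;
use tokio::sync::mpsc;
use util::Conn;

/// Minimal controlling-side flow; `remote_ufrag`/`remote_pwd` come from your signalling channel.
async fn connect(remote_ufrag: String, remote_pwd: String) -> Result<(), Error> {
    // Host candidates over IPv4 UDP only; add STUN/TURN entries to `urls` for srflx/relay.
    let agent = Agent::new(AgentConfig {
        network_types: vec![NetworkType::Udp4],
        ..Default::default()
    })
    .await?;

    // Local credentials must be sent to the remote peer out of band.
    let (_local_ufrag, _local_pwd) = agent.get_local_user_credentials().await;

    // Each gathered candidate should likewise be signalled to the peer; `None` marks completion.
    agent.on_candidate(Box::new(|c: Option<Arc<dyn Candidate + Send + Sync>>| {
        Box::pin(async move {
            if let Some(c) = c {
                println!("local candidate: {}", c.marshal());
            }
        })
    }));
    agent.gather_candidates()?;

    // The controlling agent dials; the controlled side calls `accept` instead.
    let (_cancel_tx, cancel_rx) = mpsc::channel(1);
    let conn = agent.dial(cancel_rx, remote_ufrag, remote_pwd).await?;

    // `conn` implements `util::Conn`, so datagrams can be exchanged with the peer.
    let _ = conn.send(b"hello").await;
    agent.close().await?;
    Ok(())
}
```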
diff --git a/reserved/ice/codecov.yml b/reserved/ice/codecov.yml new file mode 100644 index 0000000..99d83b7 --- /dev/null +++ b/reserved/ice/codecov.yml @@ -0,0 +1,23 @@ +codecov: + require_ci_to_pass: yes + max_report_age: off + token: 4bd5cec1-2807-4cd6-8430-d5f3efe32ce0 + +coverage: + precision: 2 + round: down + range: 50..90 + status: + project: + default: + enabled: no + threshold: 0.2 + if_not_found: success + patch: + default: + enabled: no + if_not_found: success + changes: + default: + enabled: no + if_not_found: success diff --git a/reserved/ice/doc/webrtc.rs.png b/reserved/ice/doc/webrtc.rs.png new file mode 100644 index 0000000..7bf0dda Binary files /dev/null and b/reserved/ice/doc/webrtc.rs.png differ diff --git a/reserved/ice/examples/ping_pong.rs b/reserved/ice/examples/ping_pong.rs new file mode 100644 index 0000000..143d864 --- /dev/null +++ b/reserved/ice/examples/ping_pong.rs @@ -0,0 +1,423 @@ +use std::io; +use std::sync::Arc; +use std::time::Duration; + +use clap::{App, AppSettings, Arg}; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Client, Method, Request, Response, Server, StatusCode}; +use ice::agent::agent_config::AgentConfig; +use ice::agent::Agent; +use ice::candidate::candidate_base::*; +use ice::candidate::*; +use ice::network_type::*; +use ice::state::*; +use ice::udp_network::UDPNetwork; +use ice::Error; +use rand::{thread_rng, Rng}; +use tokio::net::UdpSocket; +use tokio::sync::{mpsc, watch, Mutex}; +use util::Conn; + +#[macro_use] +extern crate lazy_static; + +type SenderType = Arc>>; +type ReceiverType = Arc>>; + +lazy_static! { + // ErrUnknownType indicates an error with Unknown info. + static ref REMOTE_AUTH_CHANNEL: (SenderType, ReceiverType ) = { + let (tx, rx) = mpsc::channel::(3); + (Arc::new(Mutex::new(tx)), Arc::new(Mutex::new(rx))) + }; + + static ref REMOTE_CAND_CHANNEL: (SenderType, ReceiverType) = { + let (tx, rx) = mpsc::channel::(10); + (Arc::new(Mutex::new(tx)), Arc::new(Mutex::new(rx))) + }; +} + +// HTTP Listener to get ICE Credentials/Candidate from remote Peer +async fn remote_handler(req: Request) -> Result, hyper::Error> { + //println!("received {:?}", req); + match (req.method(), req.uri().path()) { + (&Method::POST, "/remoteAuth") => { + let full_body = + match std::str::from_utf8(&hyper::body::to_bytes(req.into_body()).await?) { + Ok(s) => s.to_owned(), + Err(err) => panic!("{}", err), + }; + let tx = REMOTE_AUTH_CHANNEL.0.lock().await; + //println!("body: {:?}", full_body); + let _ = tx.send(full_body).await; + + let mut response = Response::new(Body::empty()); + *response.status_mut() = StatusCode::OK; + Ok(response) + } + + (&Method::POST, "/remoteCandidate") => { + let full_body = + match std::str::from_utf8(&hyper::body::to_bytes(req.into_body()).await?) { + Ok(s) => s.to_owned(), + Err(err) => panic!("{}", err), + }; + let tx = REMOTE_CAND_CHANNEL.0.lock().await; + //println!("body: {:?}", full_body); + let _ = tx.send(full_body).await; + + let mut response = Response::new(Body::empty()); + *response.status_mut() = StatusCode::OK; + Ok(response) + } + + // Return the 404 Not Found for other routes. 
+ _ => { + let mut not_found = Response::default(); + *not_found.status_mut() = StatusCode::NOT_FOUND; + Ok(not_found) + } + } +} + +// Controlled Agent: +// cargo run --color=always --package webrtc-ice --example ping_pong +// Controlling Agent: +// cargo run --color=always --package webrtc-ice --example ping_pong -- --controlling + +#[tokio::main] +async fn main() -> Result<(), Error> { + env_logger::init(); + // .format(|buf, record| { + // writeln!( + // buf, + // "{}:{} [{}] {} - {}", + // record.file().unwrap_or("unknown"), + // record.line().unwrap_or(0), + // record.level(), + // chrono::Local::now().format("%H:%M:%S.%6f"), + // record.args() + // ) + // }) + // .filter(None, log::LevelFilter::Trace) + // .init(); + + let mut app = App::new("ICE Demo") + .version("0.1.0") + .author("Rain Liu ") + .about("An example of ICE") + .setting(AppSettings::DeriveDisplayOrder) + .setting(AppSettings::SubcommandsNegateReqs) + .arg( + Arg::with_name("use-mux") + .takes_value(false) + .long("use-mux") + .short('m') + .help("Use a muxed UDP connection over a single listening port"), + ) + .arg( + Arg::with_name("FULLHELP") + .help("Prints more detailed help information") + .long("fullhelp"), + ) + .arg( + Arg::with_name("controlling") + .takes_value(false) + .long("controlling") + .help("is ICE Agent controlling"), + ); + + let matches = app.clone().get_matches(); + + if matches.is_present("FULLHELP") { + app.print_long_help().unwrap(); + std::process::exit(0); + } + + let is_controlling = matches.is_present("controlling"); + let use_mux = matches.is_present("use-mux"); + + let (local_http_port, remote_http_port) = if is_controlling { + (9000, 9001) + } else { + (9001, 9000) + }; + + let (weak_conn, weak_agent) = { + let (done_tx, done_rx) = watch::channel(()); + + println!("Listening on http://localhost:{local_http_port}"); + let mut done_http_server = done_rx.clone(); + tokio::spawn(async move { + let addr = ([0, 0, 0, 0], local_http_port).into(); + let service = + make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(remote_handler)) }); + let server = Server::bind(&addr).serve(service); + tokio::select! { + _ = done_http_server.changed() => { + println!("receive cancel http server!"); + } + result = server => { + // Run this server for... forever! 
+ if let Err(e) = result { + eprintln!("server error: {e}"); + } + println!("exit http server!"); + } + }; + }); + + if is_controlling { + println!("Local Agent is controlling"); + } else { + println!("Local Agent is controlled"); + }; + println!("Press 'Enter' when both processes have started"); + let mut input = String::new(); + let _ = io::stdin().read_line(&mut input)?; + + let udp_network = if use_mux { + use ice::udp_mux::*; + let port = if is_controlling { 4000 } else { 4001 }; + + let udp_socket = UdpSocket::bind(("0.0.0.0", port)).await?; + let udp_mux = UDPMuxDefault::new(UDPMuxParams::new(udp_socket)); + + UDPNetwork::Muxed(udp_mux) + } else { + UDPNetwork::Ephemeral(Default::default()) + }; + + let ice_agent = Arc::new( + Agent::new(AgentConfig { + network_types: vec![NetworkType::Udp4], + udp_network, + ..Default::default() + }) + .await?, + ); + + let client = Arc::new(Client::new()); + + // When we have gathered a new ICE Candidate send it to the remote peer + let client2 = Arc::clone(&client); + ice_agent.on_candidate(Box::new( + move |c: Option>| { + let client3 = Arc::clone(&client2); + Box::pin(async move { + if let Some(c) = c { + println!("posting remoteCandidate with {}", c.marshal()); + + let req = match Request::builder() + .method(Method::POST) + .uri(format!( + "http://localhost:{remote_http_port}/remoteCandidate" + )) + .body(Body::from(c.marshal())) + { + Ok(req) => req, + Err(err) => { + println!("{err}"); + return; + } + }; + let resp = match client3.request(req).await { + Ok(resp) => resp, + Err(err) => { + println!("{err}"); + return; + } + }; + println!("Response from remoteCandidate: {}", resp.status()); + } + }) + }, + )); + + let (ice_done_tx, mut ice_done_rx) = mpsc::channel::<()>(1); + // When ICE Connection state has change print to stdout + ice_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { + println!("ICE Connection State has changed: {c}"); + if c == ConnectionState::Failed { + let _ = ice_done_tx.try_send(()); + } + Box::pin(async move {}) + })); + + // Get the local auth details and send to remote peer + let (local_ufrag, local_pwd) = ice_agent.get_local_user_credentials().await; + + println!("posting remoteAuth with {local_ufrag}:{local_pwd}"); + let req = match Request::builder() + .method(Method::POST) + .uri(format!("http://localhost:{remote_http_port}/remoteAuth")) + .body(Body::from(format!("{local_ufrag}:{local_pwd}"))) + { + Ok(req) => req, + Err(err) => return Err(Error::Other(format!("{err}"))), + }; + let resp = match client.request(req).await { + Ok(resp) => resp, + Err(err) => return Err(Error::Other(format!("{err}"))), + }; + println!("Response from remoteAuth: {}", resp.status()); + + let (remote_ufrag, remote_pwd) = { + let mut rx = REMOTE_AUTH_CHANNEL.1.lock().await; + if let Some(s) = rx.recv().await { + println!("received: {s}"); + let fields: Vec = s.split(':').map(|s| s.to_string()).collect(); + (fields[0].clone(), fields[1].clone()) + } else { + panic!("rx.recv() empty"); + } + }; + println!("remote_ufrag: {remote_ufrag}, remote_pwd: {remote_pwd}"); + + let ice_agent2 = Arc::clone(&ice_agent); + let mut done_cand = done_rx.clone(); + tokio::spawn(async move { + let mut rx = REMOTE_CAND_CHANNEL.1.lock().await; + loop { + tokio::select! 
{ + _ = done_cand.changed() => { + println!("receive cancel remote cand!"); + break; + } + result = rx.recv() => { + if let Some(s) = result { + if let Ok(c) = unmarshal_candidate(&s) { + println!("add_remote_candidate: {c}"); + let c: Arc = Arc::new(c); + let _ = ice_agent2.add_remote_candidate(&c); + }else{ + println!("unmarshal_candidate error!"); + break; + } + }else{ + println!("REMOTE_CAND_CHANNEL done!"); + break; + } + } + }; + } + }); + + ice_agent.gather_candidates()?; + println!("Connecting..."); + + let (_cancel_tx, cancel_rx) = mpsc::channel(1); + // Start the ICE Agent. One side must be controlled, and the other must be controlling + let conn: Arc = if is_controlling { + ice_agent.dial(cancel_rx, remote_ufrag, remote_pwd).await? + } else { + ice_agent + .accept(cancel_rx, remote_ufrag, remote_pwd) + .await? + }; + + let weak_conn = Arc::downgrade(&conn); + + // Send messages in a loop to the remote peer + let conn_tx = Arc::clone(&conn); + let mut done_send = done_rx.clone(); + tokio::spawn(async move { + const RANDOM_STRING: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + loop { + tokio::time::sleep(Duration::from_secs(3)).await; + + let val: String = (0..15) + .map(|_| { + let idx = thread_rng().gen_range(0..RANDOM_STRING.len()); + RANDOM_STRING[idx] as char + }) + .collect(); + + tokio::select! { + _ = done_send.changed() => { + println!("receive cancel ice send!"); + break; + } + result = conn_tx.send(val.as_bytes()) => { + if let Err(err) = result { + eprintln!("conn_tx send error: {err}"); + break; + }else{ + println!("Sent: '{val}'"); + } + } + }; + } + }); + + let mut done_recv = done_rx.clone(); + tokio::spawn(async move { + // Receive messages in a loop from the remote peer + let mut buf = vec![0u8; 1500]; + loop { + tokio::select! { + _ = done_recv.changed() => { + println!("receive cancel ice recv!"); + break; + } + result = conn.recv(&mut buf) => { + match result { + Ok(n) => { + println!("Received: '{}'", std::str::from_utf8(&buf[..n]).unwrap()); + } + Err(err) => { + eprintln!("conn_tx send error: {err}"); + break; + } + }; + } + }; + } + }); + + println!("Press ctrl-c to stop"); + /*let d = if is_controlling { + Duration::from_secs(500) + } else { + Duration::from_secs(5) + }; + let timeout = tokio::time::sleep(d); + tokio::pin!(timeout);*/ + + tokio::select! 
{ + /*_ = timeout.as_mut() => { + println!("received timeout signal!"); + let _ = done_tx.send(()); + }*/ + _ = ice_done_rx.recv() => { + println!("ice_done_rx"); + let _ = done_tx.send(()); + } + _ = tokio::signal::ctrl_c() => { + println!(); + let _ = done_tx.send(()); + } + }; + + let _ = ice_agent.close().await; + + (weak_conn, Arc::downgrade(&ice_agent)) + }; + + let mut int = tokio::time::interval(Duration::from_secs(1)); + loop { + int.tick().await; + println!( + "weak_conn: weak count = {}, strong count = {}, weak_agent: weak count = {}, strong count = {}", + weak_conn.weak_count(), + weak_conn.strong_count(), + weak_agent.weak_count(), + weak_agent.strong_count(), + ); + if weak_conn.strong_count() == 0 && weak_agent.strong_count() == 0 { + break; + } + } + + Ok(()) +} diff --git a/reserved/ice/src/agent/agent_config.rs b/reserved/ice/src/agent/agent_config.rs new file mode 100644 index 0000000..4de93a8 --- /dev/null +++ b/reserved/ice/src/agent/agent_config.rs @@ -0,0 +1,255 @@ +use std::net::IpAddr; +use std::time::Duration; + +use util::vnet::net::*; + +use super::*; +use crate::error::*; +use crate::mdns::*; +use crate::network_type::*; +use crate::udp_network::UDPNetwork; +use crate::url::*; + +/// The interval at which the agent performs candidate checks in the connecting phase. +pub(crate) const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_millis(200); + +/// The interval used to keep candidates alive. +pub(crate) const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2); + +/// The default time till an Agent transitions disconnected. +pub(crate) const DEFAULT_DISCONNECTED_TIMEOUT: Duration = Duration::from_secs(5); + +/// The default time till an Agent transitions to failed after disconnected. +pub(crate) const DEFAULT_FAILED_TIMEOUT: Duration = Duration::from_secs(25); + +/// Wait time before nominating a host candidate. +pub(crate) const DEFAULT_HOST_ACCEPTANCE_MIN_WAIT: Duration = Duration::from_secs(0); + +/// Wait time before nominating a srflx candidate. +pub(crate) const DEFAULT_SRFLX_ACCEPTANCE_MIN_WAIT: Duration = Duration::from_millis(500); + +/// Wait time before nominating a prflx candidate. +pub(crate) const DEFAULT_PRFLX_ACCEPTANCE_MIN_WAIT: Duration = Duration::from_millis(1000); + +/// Wait time before nominating a relay candidate. +pub(crate) const DEFAULT_RELAY_ACCEPTANCE_MIN_WAIT: Duration = Duration::from_millis(2000); + +/// Max binding request before considering a pair failed. +pub(crate) const DEFAULT_MAX_BINDING_REQUESTS: u16 = 7; + +/// The number of bytes that can be buffered before we start to error. +pub(crate) const MAX_BUFFER_SIZE: usize = 1000 * 1000; // 1MB + +/// Wait time before binding requests can be deleted. +pub(crate) const MAX_BINDING_REQUEST_TIMEOUT: Duration = Duration::from_millis(4000); + +pub(crate) fn default_candidate_types() -> Vec { + vec![ + CandidateType::Host, + CandidateType::ServerReflexive, + CandidateType::Relay, + ] +} + +pub type InterfaceFilterFn = Box bool) + Send + Sync>; +pub type IpFilterFn = Box bool) + Send + Sync>; + +/// Collects the arguments to `ice::Agent` construction into a single structure, for +/// future-proofness of the interface. +#[derive(Default)] +pub struct AgentConfig { + pub urls: Vec, + + /// Controls how the UDP network stack works. + /// See [`UDPNetwork`] + pub udp_network: UDPNetwork, + + /// It is used to perform connectivity checks. 
The values MUST be unguessable, with at least + /// 128 bits of random number generator output used to generate the password, and at least 24 + /// bits of output to generate the username fragment. + pub local_ufrag: String, + /// It is used to perform connectivity checks. The values MUST be unguessable, with at least + /// 128 bits of random number generator output used to generate the password, and at least 24 + /// bits of output to generate the username fragment. + pub local_pwd: String, + + /// Controls mDNS behavior for the ICE agent. + pub multicast_dns_mode: MulticastDnsMode, + + /// Controls the hostname for this agent. If none is specified a random one will be generated. + pub multicast_dns_host_name: String, + + /// Control mDNS destination address + pub multicast_dns_dest_addr: String, + + /// Defaults to 5 seconds when this property is nil. + /// If the duration is 0, the ICE Agent will never go to disconnected. + pub disconnected_timeout: Option, + + /// Defaults to 25 seconds when this property is nil. + /// If the duration is 0, we will never go to failed. + pub failed_timeout: Option, + + /// Determines how often should we send ICE keepalives (should be less then connectiontimeout + /// above) when this is nil, it defaults to 10 seconds. + /// A keepalive interval of 0 means we never send keepalive packets + pub keepalive_interval: Option, + + /// An optional configuration for disabling or enabling support for specific network types. + pub network_types: Vec, + + /// An optional configuration for disabling or enabling support for specific candidate types. + pub candidate_types: Vec, + + //LoggerFactory logging.LoggerFactory + /// Controls how often our internal task loop runs when in the connecting state. + /// Only useful for testing. + pub check_interval: Duration, + + /// The max amount of binding requests the agent will send over a candidate pair for validation + /// or nomination, if after max_binding_requests the candidate is yet to answer a binding + /// request or a nomination we set the pair as failed. + pub max_binding_requests: Option, + + pub is_controlling: bool, + + /// lite agents do not perform connectivity check and only provide host candidates. + pub lite: bool, + + /// It is used along with nat1to1ips to specify which candidate type the 1:1 NAT IP addresses + /// should be mapped to. If unspecified or CandidateTypeHost, nat1to1ips are used to replace + /// host candidate IPs. If CandidateTypeServerReflexive, it will insert a srflx candidate (as + /// if it was dervied from a STUN server) with its port number being the one for the actual host + /// candidate. Other values will result in an error. + pub nat_1to1_ip_candidate_type: CandidateType, + + /// Contains a list of public IP addresses that are to be used as a host candidate or srflx + /// candidate. This is used typically for servers that are behind 1:1 D-NAT (e.g. AWS EC2 + /// instances) and to eliminate the need of server reflexisive candidate gathering. + pub nat_1to1_ips: Vec, + + /// Specify a minimum wait time before selecting host candidates. + pub host_acceptance_min_wait: Option, + /// Specify a minimum wait time before selecting srflx candidates. + pub srflx_acceptance_min_wait: Option, + /// Specify a minimum wait time before selecting prflx candidates. + pub prflx_acceptance_min_wait: Option, + /// Specify a minimum wait time before selecting relay candidates. 
+ pub relay_acceptance_min_wait: Option, + + /// Net is the our abstracted network interface for internal development purpose only + /// (see (github.com/pion/transport/vnet)[github.com/pion/transport/vnet]). + pub net: Option>, + + /// A function that you can use in order to whitelist or blacklist the interfaces which are + /// used to gather ICE candidates. + pub interface_filter: Arc>, + + /// A function that you can use in order to whitelist or blacklist + /// the ips which are used to gather ICE candidates. + pub ip_filter: Arc>, + + /// Controls if self-signed certificates are accepted when connecting to TURN servers via TLS or + /// DTLS. + pub insecure_skip_verify: bool, +} + +impl AgentConfig { + /// Populates an agent and falls back to defaults if fields are unset. + pub(crate) fn init_with_defaults(&self, a: &mut AgentInternal) { + if let Some(max_binding_requests) = self.max_binding_requests { + a.max_binding_requests = max_binding_requests; + } else { + a.max_binding_requests = DEFAULT_MAX_BINDING_REQUESTS; + } + + if let Some(host_acceptance_min_wait) = self.host_acceptance_min_wait { + a.host_acceptance_min_wait = host_acceptance_min_wait; + } else { + a.host_acceptance_min_wait = DEFAULT_HOST_ACCEPTANCE_MIN_WAIT; + } + + if let Some(srflx_acceptance_min_wait) = self.srflx_acceptance_min_wait { + a.srflx_acceptance_min_wait = srflx_acceptance_min_wait; + } else { + a.srflx_acceptance_min_wait = DEFAULT_SRFLX_ACCEPTANCE_MIN_WAIT; + } + + if let Some(prflx_acceptance_min_wait) = self.prflx_acceptance_min_wait { + a.prflx_acceptance_min_wait = prflx_acceptance_min_wait; + } else { + a.prflx_acceptance_min_wait = DEFAULT_PRFLX_ACCEPTANCE_MIN_WAIT; + } + + if let Some(relay_acceptance_min_wait) = self.relay_acceptance_min_wait { + a.relay_acceptance_min_wait = relay_acceptance_min_wait; + } else { + a.relay_acceptance_min_wait = DEFAULT_RELAY_ACCEPTANCE_MIN_WAIT; + } + + if let Some(disconnected_timeout) = self.disconnected_timeout { + a.disconnected_timeout = disconnected_timeout; + } else { + a.disconnected_timeout = DEFAULT_DISCONNECTED_TIMEOUT; + } + + if let Some(failed_timeout) = self.failed_timeout { + a.failed_timeout = failed_timeout; + } else { + a.failed_timeout = DEFAULT_FAILED_TIMEOUT; + } + + if let Some(keepalive_interval) = self.keepalive_interval { + a.keepalive_interval = keepalive_interval; + } else { + a.keepalive_interval = DEFAULT_KEEPALIVE_INTERVAL; + } + + if self.check_interval == Duration::from_secs(0) { + a.check_interval = DEFAULT_CHECK_INTERVAL; + } else { + a.check_interval = self.check_interval; + } + } + + pub(crate) fn init_ext_ip_mapping( + &self, + mdns_mode: MulticastDnsMode, + candidate_types: &[CandidateType], + ) -> Result> { + if let Some(ext_ip_mapper) = + ExternalIpMapper::new(self.nat_1to1_ip_candidate_type, &self.nat_1to1_ips)? 
+ { + if ext_ip_mapper.candidate_type == CandidateType::Host { + if mdns_mode == MulticastDnsMode::QueryAndGather { + return Err(Error::ErrMulticastDnsWithNat1to1IpMapping); + } + let mut candi_host_enabled = false; + for candi_type in candidate_types { + if *candi_type == CandidateType::Host { + candi_host_enabled = true; + break; + } + } + if !candi_host_enabled { + return Err(Error::ErrIneffectiveNat1to1IpMappingHost); + } + } else if ext_ip_mapper.candidate_type == CandidateType::ServerReflexive { + let mut candi_srflx_enabled = false; + for candi_type in candidate_types { + if *candi_type == CandidateType::ServerReflexive { + candi_srflx_enabled = true; + break; + } + } + if !candi_srflx_enabled { + return Err(Error::ErrIneffectiveNat1to1IpMappingSrflx); + } + } + + Ok(Some(ext_ip_mapper)) + } else { + Ok(None) + } + } +} diff --git a/reserved/ice/src/agent/agent_gather.rs b/reserved/ice/src/agent/agent_gather.rs new file mode 100644 index 0000000..7cf75be --- /dev/null +++ b/reserved/ice/src/agent/agent_gather.rs @@ -0,0 +1,888 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::str::FromStr; +use std::sync::Arc; + +use util::vnet::net::*; +use util::Conn; +use waitgroup::WaitGroup; + +use super::*; +use crate::candidate::candidate_base::CandidateBaseConfig; +use crate::candidate::candidate_host::CandidateHostConfig; +use crate::candidate::candidate_relay::CandidateRelayConfig; +use crate::candidate::candidate_server_reflexive::CandidateServerReflexiveConfig; +use crate::candidate::*; +use crate::error::*; +use crate::network_type::*; +use crate::udp_network::UDPNetwork; +use crate::url::{ProtoType, SchemeType, Url}; +use crate::util::*; + +const STUN_GATHER_TIMEOUT: Duration = Duration::from_secs(5); + +pub(crate) struct GatherCandidatesInternalParams { + pub(crate) udp_network: UDPNetwork, + pub(crate) candidate_types: Vec, + pub(crate) urls: Vec, + pub(crate) network_types: Vec, + pub(crate) mdns_mode: MulticastDnsMode, + pub(crate) mdns_name: String, + pub(crate) net: Arc, + pub(crate) interface_filter: Arc>, + pub(crate) ip_filter: Arc>, + pub(crate) ext_ip_mapper: Arc>, + pub(crate) agent_internal: Arc, + pub(crate) gathering_state: Arc, + pub(crate) chan_candidate_tx: ChanCandidateTx, +} + +struct GatherCandidatesLocalParams { + udp_network: UDPNetwork, + network_types: Vec, + mdns_mode: MulticastDnsMode, + mdns_name: String, + interface_filter: Arc>, + ip_filter: Arc>, + ext_ip_mapper: Arc>, + net: Arc, + agent_internal: Arc, +} + +struct GatherCandidatesLocalUDPMuxParams { + network_types: Vec, + interface_filter: Arc>, + ip_filter: Arc>, + ext_ip_mapper: Arc>, + net: Arc, + agent_internal: Arc, + udp_mux: Arc, +} + +struct GatherCandidatesSrflxMappedParasm { + network_types: Vec, + port_max: u16, + port_min: u16, + ext_ip_mapper: Arc>, + net: Arc, + agent_internal: Arc, +} + +struct GatherCandidatesSrflxParams { + urls: Vec, + network_types: Vec, + port_max: u16, + port_min: u16, + net: Arc, + agent_internal: Arc, +} + +impl Agent { + pub(crate) async fn gather_candidates_internal(params: GatherCandidatesInternalParams) { + Self::set_gathering_state( + ¶ms.chan_candidate_tx, + ¶ms.gathering_state, + GatheringState::Gathering, + ) + .await; + + let wg = WaitGroup::new(); + + for t in ¶ms.candidate_types { + match t { + CandidateType::Host => { + let local_params = GatherCandidatesLocalParams { + udp_network: params.udp_network.clone(), + network_types: params.network_types.clone(), + mdns_mode: params.mdns_mode, + mdns_name: params.mdns_name.clone(), + interface_filter: 
Arc::clone(¶ms.interface_filter), + ip_filter: Arc::clone(¶ms.ip_filter), + ext_ip_mapper: Arc::clone(¶ms.ext_ip_mapper), + net: Arc::clone(¶ms.net), + agent_internal: Arc::clone(¶ms.agent_internal), + }; + + let w = wg.worker(); + tokio::spawn(async move { + let _d = w; + + Self::gather_candidates_local(local_params).await; + }); + } + CandidateType::ServerReflexive => { + let ephemeral_config = match ¶ms.udp_network { + UDPNetwork::Ephemeral(e) => e, + // No server reflexive for muxxed connections + UDPNetwork::Muxed(_) => continue, + }; + + let srflx_params = GatherCandidatesSrflxParams { + urls: params.urls.clone(), + network_types: params.network_types.clone(), + port_max: ephemeral_config.port_max(), + port_min: ephemeral_config.port_min(), + net: Arc::clone(¶ms.net), + agent_internal: Arc::clone(¶ms.agent_internal), + }; + let w1 = wg.worker(); + tokio::spawn(async move { + let _d = w1; + + Self::gather_candidates_srflx(srflx_params).await; + }); + if let Some(ext_ip_mapper) = &*params.ext_ip_mapper { + if ext_ip_mapper.candidate_type == CandidateType::ServerReflexive { + let srflx_mapped_params = GatherCandidatesSrflxMappedParasm { + network_types: params.network_types.clone(), + port_max: ephemeral_config.port_max(), + port_min: ephemeral_config.port_min(), + ext_ip_mapper: Arc::clone(¶ms.ext_ip_mapper), + net: Arc::clone(¶ms.net), + agent_internal: Arc::clone(¶ms.agent_internal), + }; + let w2 = wg.worker(); + tokio::spawn(async move { + let _d = w2; + + Self::gather_candidates_srflx_mapped(srflx_mapped_params).await; + }); + } + } + } + CandidateType::Relay => { + let urls = params.urls.clone(); + let net = Arc::clone(¶ms.net); + let agent_internal = Arc::clone(¶ms.agent_internal); + let w = wg.worker(); + tokio::spawn(async move { + let _d = w; + + Self::gather_candidates_relay(urls, net, agent_internal).await; + }); + } + _ => {} + } + } + + // Block until all STUN and TURN URLs have been gathered (or timed out) + wg.wait().await; + + Self::set_gathering_state( + ¶ms.chan_candidate_tx, + ¶ms.gathering_state, + GatheringState::Complete, + ) + .await; + } + + async fn set_gathering_state( + chan_candidate_tx: &ChanCandidateTx, + gathering_state: &Arc, + new_state: GatheringState, + ) { + if GatheringState::from(gathering_state.load(Ordering::SeqCst)) != new_state + && new_state == GatheringState::Complete + { + let cand_tx = chan_candidate_tx.lock().await; + if let Some(tx) = &*cand_tx { + let _ = tx.send(None).await; + } + } + + gathering_state.store(new_state as u8, Ordering::SeqCst); + } + + async fn gather_candidates_local(params: GatherCandidatesLocalParams) { + let GatherCandidatesLocalParams { + udp_network, + network_types, + mdns_mode, + mdns_name, + interface_filter, + ip_filter, + ext_ip_mapper, + net, + agent_internal, + } = params; + + // If we wanna use UDP mux, do so + // FIXME: We still need to support TCP in combination with this option + if let UDPNetwork::Muxed(udp_mux) = udp_network { + let result = Self::gather_candidates_local_udp_mux(GatherCandidatesLocalUDPMuxParams { + network_types, + interface_filter, + ip_filter, + ext_ip_mapper, + net, + agent_internal, + udp_mux, + }) + .await; + + if let Err(err) = result { + log::error!("Failed to gather local candidates using UDP mux: {}", err); + } + + return; + } + + let ips = local_interfaces(&net, &interface_filter, &ip_filter, &network_types).await; + for ip in ips { + let mut mapped_ip = ip; + + if mdns_mode != MulticastDnsMode::QueryAndGather && ext_ip_mapper.is_some() { + if let Some(ext_ip_mapper2) = 
ext_ip_mapper.as_ref() { + if ext_ip_mapper2.candidate_type == CandidateType::Host { + if let Ok(mi) = ext_ip_mapper2.find_external_ip(&ip.to_string()) { + mapped_ip = mi; + } else { + log::warn!( + "[{}]: 1:1 NAT mapping is enabled but no external IP is found for {}", + agent_internal.get_name(), + ip + ); + } + } + } + } + + let address = if mdns_mode == MulticastDnsMode::QueryAndGather { + mdns_name.clone() + } else { + mapped_ip.to_string() + }; + + //TODO: for network in networks + let network = UDP.to_owned(); + if let UDPNetwork::Ephemeral(ephemeral_config) = &udp_network { + /*TODO:switch network { + case tcp: + // Handle ICE TCP passive mode + + a.log.Debugf("GetConn by ufrag: %s\n", a.localUfrag) + conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag) + if err != nil { + if !errors.Is(err, ErrTCPMuxNotInitialized) { + a.log.Warnf("error getting tcp conn by ufrag: %s %s %s\n", network, ip, a.localUfrag) + } + continue + } + port = conn.LocalAddr().(*net.TCPAddr).Port + tcpType = TCPTypePassive + // is there a way to verify that the listen address is even + // accessible from the current interface. + case udp:*/ + + let conn: Arc = match listen_udp_in_port_range( + &net, + ephemeral_config.port_max(), + ephemeral_config.port_min(), + SocketAddr::new(ip, 0), + ) + .await + { + Ok(conn) => conn, + Err(err) => { + log::warn!( + "[{}]: could not listen {} {}: {}", + agent_internal.get_name(), + network, + ip, + err + ); + continue; + } + }; + + let port = match conn.local_addr() { + Ok(addr) => addr.port(), + Err(err) => { + log::warn!( + "[{}]: could not get local addr: {}", + agent_internal.get_name(), + err + ); + continue; + } + }; + + let host_config = CandidateHostConfig { + base_config: CandidateBaseConfig { + network: network.clone(), + address, + port, + component: COMPONENT_RTP, + conn: Some(conn), + ..CandidateBaseConfig::default() + }, + ..CandidateHostConfig::default() + }; + + let candidate: Arc = + match host_config.new_candidate_host() { + Ok(candidate) => { + if mdns_mode == MulticastDnsMode::QueryAndGather { + if let Err(err) = candidate.set_ip(&ip) { + log::warn!( + "[{}]: Failed to create host candidate: {} {} {}: {:?}", + agent_internal.get_name(), + network, + mapped_ip, + port, + err + ); + continue; + } + } + Arc::new(candidate) + } + Err(err) => { + log::warn!( + "[{}]: Failed to create host candidate: {} {} {}: {}", + agent_internal.get_name(), + network, + mapped_ip, + port, + err + ); + continue; + } + }; + + { + if let Err(err) = agent_internal.add_candidate(&candidate).await { + if let Err(close_err) = candidate.close().await { + log::warn!( + "[{}]: Failed to close candidate: {}", + agent_internal.get_name(), + close_err + ); + } + log::warn!( + "[{}]: Failed to append to localCandidates and run onCandidateHdlr: {}", + agent_internal.get_name(), + err + ); + } + } + } + } + } + + async fn gather_candidates_local_udp_mux( + params: GatherCandidatesLocalUDPMuxParams, + ) -> Result<()> { + let GatherCandidatesLocalUDPMuxParams { + network_types, + interface_filter, + ip_filter, + ext_ip_mapper, + net, + agent_internal, + udp_mux, + } = params; + + // Filter out non UDP network types + let relevant_network_types: Vec<_> = + network_types.into_iter().filter(|n| n.is_udp()).collect(); + + let udp_mux = Arc::clone(&udp_mux); + + // There's actually only one, but `local_interfaces` requires a slice. 
+ let local_ips = + local_interfaces(&net, &interface_filter, &ip_filter, &relevant_network_types).await; + + let candidate_ip = ext_ip_mapper + .as_ref() // Arc + .as_ref() // Option + .and_then(|mapper| { + if mapper.candidate_type != CandidateType::Host { + return None; + } + + local_ips + .iter() + .find_map(|ip| match mapper.find_external_ip(&ip.to_string()) { + Ok(ip) => Some(ip), + Err(err) => { + log::warn!( + "1:1 NAT mapping is enabled but not external IP is found for {}: {}", + ip, + err + ); + None + } + }) + }) + .or_else(|| local_ips.iter().copied().next()); + + let candidate_ip = match candidate_ip { + None => return Err(Error::ErrCandidateIpNotFound), + Some(ip) => ip, + }; + + let ufrag = { + let ufrag_pwd = agent_internal.ufrag_pwd.lock().await; + + ufrag_pwd.local_ufrag.clone() + }; + + let conn = udp_mux.get_conn(&ufrag).await?; + let port = conn.local_addr()?.port(); + + let host_config = CandidateHostConfig { + base_config: CandidateBaseConfig { + network: UDP.to_owned(), + address: candidate_ip.to_string(), + port, + conn: Some(conn), + component: COMPONENT_RTP, + ..Default::default() + }, + tcp_type: TcpType::Unspecified, + }; + + let candidate: Arc = + Arc::new(host_config.new_candidate_host()?); + + agent_internal.add_candidate(&candidate).await?; + + Ok(()) + } + + async fn gather_candidates_srflx_mapped(params: GatherCandidatesSrflxMappedParasm) { + let GatherCandidatesSrflxMappedParasm { + network_types, + port_max, + port_min, + ext_ip_mapper, + net, + agent_internal, + } = params; + + let wg = WaitGroup::new(); + + for network_type in network_types { + if network_type.is_tcp() { + continue; + } + + let network = network_type.to_string(); + let net2 = Arc::clone(&net); + let agent_internal2 = Arc::clone(&agent_internal); + let ext_ip_mapper2 = Arc::clone(&ext_ip_mapper); + + let w = wg.worker(); + tokio::spawn(async move { + let _d = w; + + let conn: Arc = match listen_udp_in_port_range( + &net2, + port_max, + port_min, + if network_type.is_ipv4() { + SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0) + } else { + SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0) + }, + ) + .await + { + Ok(conn) => conn, + Err(err) => { + log::warn!( + "[{}]: Failed to listen {}: {}", + agent_internal2.get_name(), + network, + err + ); + return Ok(()); + } + }; + + let laddr = conn.local_addr()?; + let mapped_ip = { + if let Some(ext_ip_mapper3) = &*ext_ip_mapper2 { + match ext_ip_mapper3.find_external_ip(&laddr.ip().to_string()) { + Ok(ip) => ip, + Err(err) => { + log::warn!( + "[{}]: 1:1 NAT mapping is enabled but no external IP is found for {}: {}", + agent_internal2.get_name(), + laddr, + err + ); + return Ok(()); + } + } + } else { + log::error!( + "[{}]: ext_ip_mapper is None in gather_candidates_srflx_mapped", + agent_internal2.get_name(), + ); + return Ok(()); + } + }; + + let srflx_config = CandidateServerReflexiveConfig { + base_config: CandidateBaseConfig { + network: network.clone(), + address: mapped_ip.to_string(), + port: laddr.port(), + component: COMPONENT_RTP, + conn: Some(conn), + ..CandidateBaseConfig::default() + }, + rel_addr: laddr.ip().to_string(), + rel_port: laddr.port(), + }; + + let candidate: Arc = + match srflx_config.new_candidate_server_reflexive() { + Ok(candidate) => Arc::new(candidate), + Err(err) => { + log::warn!( + "[{}]: Failed to create server reflexive candidate: {} {} {}: {}", + agent_internal2.get_name(), + network, + mapped_ip, + laddr.port(), + err + ); + return Ok(()); + } + }; + + { + if let Err(err) = 
agent_internal2.add_candidate(&candidate).await { + if let Err(close_err) = candidate.close().await { + log::warn!( + "[{}]: Failed to close candidate: {}", + agent_internal2.get_name(), + close_err + ); + } + log::warn!( + "[{}]: Failed to append to localCandidates and run onCandidateHdlr: {}", + agent_internal2.get_name(), + err + ); + } + } + + Result::<()>::Ok(()) + }); + } + + wg.wait().await; + } + + async fn gather_candidates_srflx(params: GatherCandidatesSrflxParams) { + let GatherCandidatesSrflxParams { + urls, + network_types, + port_max, + port_min, + net, + agent_internal, + } = params; + + let wg = WaitGroup::new(); + for network_type in network_types { + if network_type.is_tcp() { + continue; + } + + for url in &urls { + let network = network_type.to_string(); + let is_ipv4 = network_type.is_ipv4(); + let url = url.clone(); + let net2 = Arc::clone(&net); + let agent_internal2 = Arc::clone(&agent_internal); + + let w = wg.worker(); + tokio::spawn(async move { + let _d = w; + + let host_port = format!("{}:{}", url.host, url.port); + let server_addr = match net2.resolve_addr(is_ipv4, &host_port).await { + Ok(addr) => addr, + Err(err) => { + log::warn!( + "[{}]: failed to resolve stun host: {}: {}", + agent_internal2.get_name(), + host_port, + err + ); + return Ok(()); + } + }; + + let conn: Arc = match listen_udp_in_port_range( + &net2, + port_max, + port_min, + if is_ipv4 { + SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0) + } else { + SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0) + }, + ) + .await + { + Ok(conn) => conn, + Err(err) => { + log::warn!( + "[{}]: Failed to listen for {}: {}", + agent_internal2.get_name(), + server_addr, + err + ); + return Ok(()); + } + }; + + let xoraddr = + match get_xormapped_addr(&conn, server_addr, STUN_GATHER_TIMEOUT).await { + Ok(xoraddr) => xoraddr, + Err(err) => { + log::warn!( + "[{}]: could not get server reflexive address {} {}: {}", + agent_internal2.get_name(), + network, + url, + err + ); + return Ok(()); + } + }; + + let (ip, port) = (xoraddr.ip, xoraddr.port); + + let laddr = conn.local_addr()?; + let srflx_config = CandidateServerReflexiveConfig { + base_config: CandidateBaseConfig { + network: network.clone(), + address: ip.to_string(), + port, + component: COMPONENT_RTP, + conn: Some(conn), + ..CandidateBaseConfig::default() + }, + rel_addr: laddr.ip().to_string(), + rel_port: laddr.port(), + }; + + let candidate: Arc = + match srflx_config.new_candidate_server_reflexive() { + Ok(candidate) => Arc::new(candidate), + Err(err) => { + log::warn!( + "[{}]: Failed to create server reflexive candidate: {} {} {}: {:?}", + agent_internal2.get_name(), + network, + ip, + port, + err + ); + return Ok(()); + } + }; + + { + if let Err(err) = agent_internal2.add_candidate(&candidate).await { + if let Err(close_err) = candidate.close().await { + log::warn!( + "[{}]: Failed to close candidate: {}", + agent_internal2.get_name(), + close_err + ); + } + log::warn!( + "[{}]: Failed to append to localCandidates and run onCandidateHdlr: {}", + agent_internal2.get_name(), + err + ); + } + } + + Result::<()>::Ok(()) + }); + } + } + + wg.wait().await; + } + + pub(crate) async fn gather_candidates_relay( + urls: Vec, + net: Arc, + agent_internal: Arc, + ) { + let wg = WaitGroup::new(); + + for url in urls { + if url.scheme != SchemeType::Turn && url.scheme != SchemeType::Turns { + continue; + } + if url.username.is_empty() { + log::error!( + "[{}]:Failed to gather relay candidates: {:?}", + agent_internal.get_name(), + 
Error::ErrUsernameEmpty + ); + return; + } + if url.password.is_empty() { + log::error!( + "[{}]: Failed to gather relay candidates: {:?}", + agent_internal.get_name(), + Error::ErrPasswordEmpty + ); + return; + } + + let network = NetworkType::Udp4.to_string(); + let net2 = Arc::clone(&net); + let agent_internal2 = Arc::clone(&agent_internal); + + let w = wg.worker(); + tokio::spawn(async move { + let _d = w; + + let turn_server_addr = format!("{}:{}", url.host, url.port); + + let (loc_conn, rel_addr, rel_port) = + if url.proto == ProtoType::Udp && url.scheme == SchemeType::Turn { + let loc_conn = match net2.bind(SocketAddr::from_str("0.0.0.0:0")?).await { + Ok(c) => c, + Err(err) => { + log::warn!( + "[{}]: Failed to listen due to error: {}", + agent_internal2.get_name(), + err + ); + return Ok(()); + } + }; + + let local_addr = loc_conn.local_addr()?; + let rel_addr = local_addr.ip().to_string(); + let rel_port = local_addr.port(); + (loc_conn, rel_addr, rel_port) + /*TODO: case url.proto == ProtoType::UDP && url.scheme == SchemeType::TURNS{ + case a.proxyDialer != nil && url.Proto == ProtoTypeTCP && (url.Scheme == SchemeTypeTURN || url.Scheme == SchemeTypeTURNS): + case url.Proto == ProtoTypeTCP && url.Scheme == SchemeTypeTURN: + case url.Proto == ProtoTypeTCP && url.Scheme == SchemeTypeTURNS:*/ + } else { + log::warn!( + "[{}]: Unable to handle URL in gather_candidates_relay {}", + agent_internal2.get_name(), + url + ); + return Ok(()); + }; + + let cfg = turn::client::ClientConfig { + stun_serv_addr: String::new(), + turn_serv_addr: turn_server_addr.clone(), + username: url.username, + password: url.password, + realm: String::new(), + software: String::new(), + rto_in_ms: 0, + conn: loc_conn, + vnet: Some(Arc::clone(&net2)), + }; + let client = match turn::client::Client::new(cfg).await { + Ok(client) => Arc::new(client), + Err(err) => { + log::warn!( + "[{}]: Failed to build new turn.Client {} {}\n", + agent_internal2.get_name(), + turn_server_addr, + err + ); + return Ok(()); + } + }; + if let Err(err) = client.listen().await { + let _ = client.close().await; + log::warn!( + "[{}]: Failed to listen on turn.Client {} {}", + agent_internal2.get_name(), + turn_server_addr, + err + ); + return Ok(()); + } + + let relay_conn = match client.allocate().await { + Ok(conn) => conn, + Err(err) => { + let _ = client.close().await; + log::warn!( + "[{}]: Failed to allocate on turn.Client {} {}", + agent_internal2.get_name(), + turn_server_addr, + err + ); + return Ok(()); + } + }; + + let raddr = relay_conn.local_addr()?; + let relay_config = CandidateRelayConfig { + base_config: CandidateBaseConfig { + network: network.clone(), + address: raddr.ip().to_string(), + port: raddr.port(), + component: COMPONENT_RTP, + conn: Some(Arc::new(relay_conn)), + ..CandidateBaseConfig::default() + }, + rel_addr, + rel_port, + relay_client: Some(Arc::clone(&client)), + }; + + let candidate: Arc = + match relay_config.new_candidate_relay() { + Ok(candidate) => Arc::new(candidate), + Err(err) => { + let _ = client.close().await; + log::warn!( + "[{}]: Failed to create relay candidate: {} {}: {}", + agent_internal2.get_name(), + network, + raddr, + err + ); + return Ok(()); + } + }; + + { + if let Err(err) = agent_internal2.add_candidate(&candidate).await { + if let Err(close_err) = candidate.close().await { + log::warn!( + "[{}]: Failed to close candidate: {}", + agent_internal2.get_name(), + close_err + ); + } + log::warn!( + "[{}]: Failed to append to localCandidates and run onCandidateHdlr: {}", + 
agent_internal2.get_name(), + err + ); + } + } + + Result::<()>::Ok(()) + }); + } + + wg.wait().await; + } +} diff --git a/reserved/ice/src/agent/agent_gather_test.rs b/reserved/ice/src/agent/agent_gather_test.rs new file mode 100644 index 0000000..ed90f09 --- /dev/null +++ b/reserved/ice/src/agent/agent_gather_test.rs @@ -0,0 +1,490 @@ +use std::str::FromStr; + +use ipnet::IpNet; +use tokio::net::UdpSocket; +use util::vnet::*; + +use super::agent_vnet_test::*; +use super::*; +use crate::udp_mux::{UDPMuxDefault, UDPMuxParams}; +use crate::util::*; + +#[tokio::test] +async fn test_vnet_gather_no_local_ip_address() -> Result<()> { + let vnet = Arc::new(net::Net::new(Some(net::NetConfig::default()))); + + let a = Agent::new(AgentConfig { + net: Some(Arc::clone(&vnet)), + ..Default::default() + }) + .await?; + + let local_ips = local_interfaces( + &vnet, + &a.interface_filter, + &a.ip_filter, + &[NetworkType::Udp4], + ) + .await; + assert!(local_ips.is_empty(), "should return no local IP"); + + a.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_vnet_gather_dynamic_ip_address() -> Result<()> { + let cider = "1.2.3.0/24"; + let ipnet = IpNet::from_str(cider).map_err(|e| Error::Other(e.to_string()))?; + + let r = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + cidr: cider.to_owned(), + ..Default::default() + })?)); + let nw = Arc::new(net::Net::new(Some(net::NetConfig::default()))); + connect_net2router(&nw, &r).await?; + + let a = Agent::new(AgentConfig { + net: Some(Arc::clone(&nw)), + ..Default::default() + }) + .await?; + + let local_ips = + local_interfaces(&nw, &a.interface_filter, &a.ip_filter, &[NetworkType::Udp4]).await; + assert!(!local_ips.is_empty(), "should have one local IP"); + + for ip in &local_ips { + if ip.is_loopback() { + panic!("should not return loopback IP"); + } + if !ipnet.contains(ip) { + panic!("{ip} should be contained in the CIDR {ipnet}"); + } + } + + a.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_vnet_gather_listen_udp() -> Result<()> { + let cider = "1.2.3.0/24"; + let r = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + cidr: cider.to_owned(), + ..Default::default() + })?)); + let nw = Arc::new(net::Net::new(Some(net::NetConfig::default()))); + connect_net2router(&nw, &r).await?; + + let a = Agent::new(AgentConfig { + net: Some(Arc::clone(&nw)), + ..Default::default() + }) + .await?; + + let local_ips = + local_interfaces(&nw, &a.interface_filter, &a.ip_filter, &[NetworkType::Udp4]).await; + assert!(!local_ips.is_empty(), "should have one local IP"); + + for ip in local_ips { + let _ = listen_udp_in_port_range(&nw, 0, 0, SocketAddr::new(ip, 0)).await?; + + let result = listen_udp_in_port_range(&nw, 4999, 5000, SocketAddr::new(ip, 0)).await; + assert!( + result.is_err(), + "listenUDP with invalid port range did not return ErrPort" + ); + + let conn = listen_udp_in_port_range(&nw, 5000, 5000, SocketAddr::new(ip, 0)).await?; + let port = conn.local_addr()?.port(); + assert_eq!( + port, 5000, + "listenUDP with port restriction of 5000 listened on incorrect port ({port})" + ); + } + + a.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_vnet_gather_with_nat_1to1_as_host_candidates() -> Result<()> { + let external_ip0 = "1.2.3.4"; + let external_ip1 = "1.2.3.5"; + let local_ip0 = "10.0.0.1"; + let local_ip1 = "10.0.0.2"; + let map0 = format!("{external_ip0}/{local_ip0}"); + let map1 = format!("{external_ip1}/{local_ip1}"); + + let wan = 
Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + cidr: "1.2.3.0/24".to_owned(), + ..Default::default() + })?)); + + let lan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + cidr: "10.0.0.0/24".to_owned(), + static_ips: vec![map0.clone(), map1.clone()], + nat_type: Some(nat::NatType { + mode: nat::NatMode::Nat1To1, + ..Default::default() + }), + ..Default::default() + })?)); + + connect_router2router(&lan, &wan).await?; + + let nw = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec![local_ip0.to_owned(), local_ip1.to_owned()], + ..Default::default() + }))); + + connect_net2router(&nw, &lan).await?; + + let a = Agent::new(AgentConfig { + network_types: vec![NetworkType::Udp4], + nat_1to1_ips: vec![map0.clone(), map1.clone()], + net: Some(Arc::clone(&nw)), + ..Default::default() + }) + .await?; + + let (done_tx, mut done_rx) = mpsc::channel::<()>(1); + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + a.on_candidate(Box::new( + move |c: Option>| { + let done_tx_clone = Arc::clone(&done_tx); + Box::pin(async move { + if c.is_none() { + let mut tx = done_tx_clone.lock().await; + tx.take(); + } + }) + }, + )); + + a.gather_candidates()?; + + log::debug!("wait for gathering is done..."); + let _ = done_rx.recv().await; + log::debug!("gathering is done"); + + let candidates = a.get_local_candidates().await?; + assert_eq!(candidates.len(), 2, "There must be two candidates"); + + let mut laddrs = vec![]; + for candi in &candidates { + if let Some(conn) = candi.get_conn() { + let laddr = conn.local_addr()?; + assert_eq!( + candi.port(), + laddr.port(), + "Unexpected candidate port: {}", + candi.port() + ); + laddrs.push(laddr); + } + } + + if candidates[0].address() == external_ip0 { + assert_eq!( + candidates[1].address(), + external_ip1, + "Unexpected candidate IP: {}", + candidates[1].address() + ); + assert_eq!( + laddrs[0].ip().to_string(), + local_ip0, + "Unexpected listen IP: {}", + laddrs[0].ip() + ); + assert_eq!( + laddrs[1].ip().to_string(), + local_ip1, + "Unexpected listen IP: {}", + laddrs[1].ip() + ); + } else if candidates[0].address() == external_ip1 { + assert_eq!( + candidates[1].address(), + external_ip0, + "Unexpected candidate IP: {}", + candidates[1].address() + ); + assert_eq!( + laddrs[0].ip().to_string(), + local_ip1, + "Unexpected listen IP: {}", + laddrs[0].ip(), + ); + assert_eq!( + laddrs[1].ip().to_string(), + local_ip0, + "Unexpected listen IP: {}", + laddrs[1].ip(), + ) + } + + a.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_vnet_gather_with_nat_1to1_as_srflx_candidates() -> Result<()> { + let wan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + cidr: "1.2.3.0/24".to_owned(), + ..Default::default() + })?)); + + let lan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + cidr: "10.0.0.0/24".to_owned(), + static_ips: vec!["1.2.3.4/10.0.0.1".to_owned()], + nat_type: Some(nat::NatType { + mode: nat::NatMode::Nat1To1, + ..Default::default() + }), + ..Default::default() + })?)); + + connect_router2router(&lan, &wan).await?; + + let nw = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec!["10.0.0.1".to_owned()], + ..Default::default() + }))); + + connect_net2router(&nw, &lan).await?; + + let a = Agent::new(AgentConfig { + network_types: vec![NetworkType::Udp4], + nat_1to1_ips: vec!["1.2.3.4".to_owned()], + nat_1to1_ip_candidate_type: CandidateType::ServerReflexive, + net: Some(nw), + ..Default::default() + }) + .await?; + + let (done_tx, mut done_rx) = 
mpsc::channel::<()>(1); + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + a.on_candidate(Box::new( + move |c: Option>| { + let done_tx_clone = Arc::clone(&done_tx); + Box::pin(async move { + if c.is_none() { + let mut tx = done_tx_clone.lock().await; + tx.take(); + } + }) + }, + )); + + a.gather_candidates()?; + + log::debug!("wait for gathering is done..."); + let _ = done_rx.recv().await; + log::debug!("gathering is done"); + + let candidates = a.get_local_candidates().await?; + assert_eq!(candidates.len(), 2, "There must be two candidates"); + + let mut candi_host = None; + let mut candi_srflx = None; + + for candidate in candidates { + match candidate.candidate_type() { + CandidateType::Host => { + candi_host = Some(candidate); + } + CandidateType::ServerReflexive => { + candi_srflx = Some(candidate); + } + _ => { + panic!("Unexpected candidate type"); + } + } + } + + assert!(candi_host.is_some(), "should not be nil"); + assert_eq!("10.0.0.1", candi_host.unwrap().address(), "should match"); + assert!(candi_srflx.is_some(), "should not be nil"); + assert_eq!("1.2.3.4", candi_srflx.unwrap().address(), "should match"); + + a.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_vnet_gather_with_interface_filter() -> Result<()> { + let r = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + cidr: "1.2.3.0/24".to_owned(), + ..Default::default() + })?)); + let nw = Arc::new(net::Net::new(Some(net::NetConfig::default()))); + connect_net2router(&nw, &r).await?; + + //"InterfaceFilter should exclude the interface" + { + let a = Agent::new(AgentConfig { + net: Some(Arc::clone(&nw)), + interface_filter: Arc::new(Some(Box::new(|_: &str| -> bool { + //assert_eq!("eth0", interface_name); + false + }))), + ..Default::default() + }) + .await?; + + let local_ips = + local_interfaces(&nw, &a.interface_filter, &a.ip_filter, &[NetworkType::Udp4]).await; + assert!( + local_ips.is_empty(), + "InterfaceFilter should have excluded everything" + ); + + a.close().await?; + } + + //"InterfaceFilter should not exclude the interface" + { + let a = Agent::new(AgentConfig { + net: Some(Arc::clone(&nw)), + interface_filter: Arc::new(Some(Box::new(|interface_name: &str| -> bool { + "eth0" == interface_name + }))), + ..Default::default() + }) + .await?; + + let local_ips = + local_interfaces(&nw, &a.interface_filter, &a.ip_filter, &[NetworkType::Udp4]).await; + assert_eq!( + local_ips.len(), + 1, + "InterfaceFilter should not have excluded everything" + ); + + a.close().await?; + } + + Ok(()) +} + +#[tokio::test] +async fn test_vnet_gather_turn_connection_leak() -> Result<()> { + let turn_server_url = Url { + scheme: SchemeType::Turn, + host: VNET_STUN_SERVER_IP.to_owned(), + port: VNET_STUN_SERVER_PORT, + username: "user".to_owned(), + password: "pass".to_owned(), + proto: ProtoType::Udp, + }; + + // buildVNet with a Symmetric NATs for both LANs + let nat_type = nat::NatType { + mapping_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, + filtering_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, + ..Default::default() + }; + + let v = build_vnet(nat_type, nat_type).await?; + + let cfg0 = AgentConfig { + urls: vec![turn_server_url.clone()], + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + nat_1to1_ips: vec![VNET_GLOBAL_IPA.to_owned()], + net: Some(Arc::clone(&v.net0)), + ..Default::default() + }; + + let a_agent = Agent::new(cfg0).await?; + + { + let agent_internal = Arc::clone(&a_agent.internal); + 
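// Gather relay candidates directly against the vnet TURN server; the test then checks that closing the agent releases the TURN client connection allocated here. +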
Agent::gather_candidates_relay( + vec![turn_server_url.clone()], + Arc::clone(&v.net0), + agent_internal, + ) + .await; + } + + // Assert relay conn leak on close. + a_agent.close().await?; + v.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_vnet_gather_muxed_udp() -> Result<()> { + let udp_socket = UdpSocket::bind("0.0.0.0:0").await?; + let udp_mux = UDPMuxDefault::new(UDPMuxParams::new(udp_socket)); + + let lan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + cidr: "10.0.0.0/24".to_owned(), + nat_type: Some(nat::NatType { + mode: nat::NatMode::Nat1To1, + ..Default::default() + }), + ..Default::default() + })?)); + + let nw = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec!["10.0.0.1".to_owned()], + ..Default::default() + }))); + + connect_net2router(&nw, &lan).await?; + + let a = Agent::new(AgentConfig { + network_types: vec![NetworkType::Udp4], + nat_1to1_ips: vec!["1.2.3.4".to_owned()], + net: Some(nw), + udp_network: UDPNetwork::Muxed(udp_mux), + ..Default::default() + }) + .await?; + + let (done_tx, mut done_rx) = mpsc::channel::<()>(1); + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + a.on_candidate(Box::new( + move |c: Option>| { + let done_tx_clone = Arc::clone(&done_tx); + Box::pin(async move { + if c.is_none() { + let mut tx = done_tx_clone.lock().await; + tx.take(); + } + }) + }, + )); + + a.gather_candidates()?; + + log::debug!("wait for gathering is done..."); + let _ = done_rx.recv().await; + log::debug!("gathering is done"); + + let candidates = a.get_local_candidates().await?; + assert_eq!(candidates.len(), 1, "There must be a single candidate"); + + let candi = &candidates[0]; + let laddr = candi.get_conn().unwrap().local_addr()?; + assert_eq!(candi.address(), "1.2.3.4"); + assert_eq!( + candi.port(), + laddr.port(), + "Unexpected candidate port: {}", + candi.port() + ); + + Ok(()) +} diff --git a/reserved/ice/src/agent/agent_internal.rs b/reserved/ice/src/agent/agent_internal.rs new file mode 100644 index 0000000..8fd2b22 --- /dev/null +++ b/reserved/ice/src/agent/agent_internal.rs @@ -0,0 +1,1198 @@ +use std::sync::atomic::{AtomicBool, AtomicU64}; + +use arc_swap::ArcSwapOption; +use util::sync::Mutex as SyncMutex; + +use super::agent_transport::*; +use super::*; +use crate::candidate::candidate_base::CandidateBaseConfig; +use crate::candidate::candidate_peer_reflexive::CandidatePeerReflexiveConfig; +use crate::util::*; + +pub type ChanCandidateTx = + Arc>>>>>; + +#[derive(Default)] +pub(crate) struct UfragPwd { + pub(crate) local_ufrag: String, + pub(crate) local_pwd: String, + pub(crate) remote_ufrag: String, + pub(crate) remote_pwd: String, +} + +pub struct AgentInternal { + // State owned by the taskLoop + pub(crate) on_connected_tx: Mutex>>, + pub(crate) on_connected_rx: Mutex>>, + + // State for closing + pub(crate) done_tx: Mutex>>, + // force candidate to be contacted immediately (instead of waiting for task ticker) + pub(crate) force_candidate_contact_tx: mpsc::Sender, + pub(crate) done_and_force_candidate_contact_rx: + Mutex, mpsc::Receiver)>>, + + pub(crate) chan_candidate_tx: ChanCandidateTx, + pub(crate) chan_candidate_pair_tx: Mutex>>, + pub(crate) chan_state_tx: Mutex>>, + + pub(crate) on_connection_state_change_hdlr: ArcSwapOption>, + pub(crate) on_selected_candidate_pair_change_hdlr: + ArcSwapOption>, + pub(crate) on_candidate_hdlr: ArcSwapOption>, + + pub(crate) tie_breaker: AtomicU64, + pub(crate) is_controlling: AtomicBool, + pub(crate) lite: AtomicBool, + + pub(crate) start_time: SyncMutex, + 
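// Candidate pair the controlling agent intends to nominate; stays None until the selector picks one. +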
pub(crate) nominated_pair: Mutex>>, + + pub(crate) connection_state: AtomicU8, //ConnectionState, + + pub(crate) started_ch_tx: Mutex>>, + + pub(crate) ufrag_pwd: Mutex, + + pub(crate) local_candidates: Mutex>>>, + pub(crate) remote_candidates: + Mutex>>>, + + // LRU of outbound Binding request Transaction IDs + pub(crate) pending_binding_requests: Mutex>, + + pub(crate) agent_conn: Arc, + + // the following variables won't be changed after init_with_defaults() + pub(crate) insecure_skip_verify: bool, + pub(crate) max_binding_requests: u16, + pub(crate) host_acceptance_min_wait: Duration, + pub(crate) srflx_acceptance_min_wait: Duration, + pub(crate) prflx_acceptance_min_wait: Duration, + pub(crate) relay_acceptance_min_wait: Duration, + // How long connectivity checks can fail before the ICE Agent + // goes to disconnected + pub(crate) disconnected_timeout: Duration, + // How long connectivity checks can fail before the ICE Agent + // goes to failed + pub(crate) failed_timeout: Duration, + // How often should we send keepalive packets? + // 0 means never + pub(crate) keepalive_interval: Duration, + // How often should we run our internal taskLoop to check for state changes when connecting + pub(crate) check_interval: Duration, +} + +impl AgentInternal { + pub(super) fn new(config: &AgentConfig) -> (Self, ChanReceivers) { + let (chan_state_tx, chan_state_rx) = mpsc::channel(1); + let (chan_candidate_tx, chan_candidate_rx) = mpsc::channel(1); + let (chan_candidate_pair_tx, chan_candidate_pair_rx) = mpsc::channel(1); + let (on_connected_tx, on_connected_rx) = mpsc::channel(1); + let (done_tx, done_rx) = mpsc::channel(1); + let (force_candidate_contact_tx, force_candidate_contact_rx) = mpsc::channel(1); + let (started_ch_tx, _) = broadcast::channel(1); + + let ai = AgentInternal { + on_connected_tx: Mutex::new(Some(on_connected_tx)), + on_connected_rx: Mutex::new(Some(on_connected_rx)), + + done_tx: Mutex::new(Some(done_tx)), + force_candidate_contact_tx, + done_and_force_candidate_contact_rx: Mutex::new(Some(( + done_rx, + force_candidate_contact_rx, + ))), + + chan_candidate_tx: Arc::new(Mutex::new(Some(chan_candidate_tx))), + chan_candidate_pair_tx: Mutex::new(Some(chan_candidate_pair_tx)), + chan_state_tx: Mutex::new(Some(chan_state_tx)), + + on_connection_state_change_hdlr: ArcSwapOption::empty(), + on_selected_candidate_pair_change_hdlr: ArcSwapOption::empty(), + on_candidate_hdlr: ArcSwapOption::empty(), + + tie_breaker: AtomicU64::new(rand::random::()), + is_controlling: AtomicBool::new(config.is_controlling), + lite: AtomicBool::new(config.lite), + + start_time: SyncMutex::new(Instant::now()), + nominated_pair: Mutex::new(None), + + connection_state: AtomicU8::new(ConnectionState::New as u8), + + insecure_skip_verify: config.insecure_skip_verify, + + started_ch_tx: Mutex::new(Some(started_ch_tx)), + + //won't change after init_with_defaults() + max_binding_requests: 0, + host_acceptance_min_wait: Duration::from_secs(0), + srflx_acceptance_min_wait: Duration::from_secs(0), + prflx_acceptance_min_wait: Duration::from_secs(0), + relay_acceptance_min_wait: Duration::from_secs(0), + + // How long connectivity checks can fail before the ICE Agent + // goes to disconnected + disconnected_timeout: Duration::from_secs(0), + + // How long connectivity checks can fail before the ICE Agent + // goes to failed + failed_timeout: Duration::from_secs(0), + + // How often should we send keepalive packets? 
+ // 0 means never + keepalive_interval: Duration::from_secs(0), + + // How often should we run our internal taskLoop to check for state changes when connecting + check_interval: Duration::from_secs(0), + + ufrag_pwd: Mutex::new(UfragPwd::default()), + + local_candidates: Mutex::new(HashMap::new()), + remote_candidates: Mutex::new(HashMap::new()), + + // LRU of outbound Binding request Transaction IDs + pending_binding_requests: Mutex::new(vec![]), + + // AgentConn + agent_conn: Arc::new(AgentConn::new()), + }; + + let chan_receivers = ChanReceivers { + chan_state_rx, + chan_candidate_rx, + chan_candidate_pair_rx, + }; + (ai, chan_receivers) + } + pub(crate) async fn start_connectivity_checks( + self: &Arc, + is_controlling: bool, + remote_ufrag: String, + remote_pwd: String, + ) -> Result<()> { + { + let started_ch_tx = self.started_ch_tx.lock().await; + if started_ch_tx.is_none() { + return Err(Error::ErrMultipleStart); + } + } + + log::debug!( + "Started agent: isControlling? {}, remoteUfrag: {}, remotePwd: {}", + is_controlling, + remote_ufrag, + remote_pwd + ); + self.set_remote_credentials(remote_ufrag, remote_pwd) + .await?; + self.is_controlling.store(is_controlling, Ordering::SeqCst); + self.start().await; + { + let mut started_ch_tx = self.started_ch_tx.lock().await; + started_ch_tx.take(); + } + + self.update_connection_state(ConnectionState::Checking) + .await; + + self.request_connectivity_check(); + + self.connectivity_checks().await; + + Ok(()) + } + + async fn contact( + &self, + last_connection_state: &mut ConnectionState, + checking_duration: &mut Instant, + ) { + if self.connection_state.load(Ordering::SeqCst) == ConnectionState::Failed as u8 { + // The connection is currently failed so don't send any checks + // In the future it may be restarted though + *last_connection_state = self.connection_state.load(Ordering::SeqCst).into(); + return; + } + if self.connection_state.load(Ordering::SeqCst) == ConnectionState::Checking as u8 { + // We have just entered checking for the first time so update our checking timer + if *last_connection_state as u8 != self.connection_state.load(Ordering::SeqCst) { + *checking_duration = Instant::now(); + } + + // We have been in checking longer then Disconnect+Failed timeout, set the connection to Failed + if Instant::now() + .checked_duration_since(*checking_duration) + .unwrap_or_else(|| Duration::from_secs(0)) + > self.disconnected_timeout + self.failed_timeout + { + self.update_connection_state(ConnectionState::Failed).await; + *last_connection_state = self.connection_state.load(Ordering::SeqCst).into(); + return; + } + } + + self.contact_candidates().await; + + *last_connection_state = self.connection_state.load(Ordering::SeqCst).into(); + } + + async fn connectivity_checks(self: &Arc) { + const ZERO_DURATION: Duration = Duration::from_secs(0); + let mut last_connection_state = ConnectionState::Unspecified; + let mut checking_duration = Instant::now(); + let (check_interval, keepalive_interval, disconnected_timeout, failed_timeout) = ( + self.check_interval, + self.keepalive_interval, + self.disconnected_timeout, + self.failed_timeout, + ); + + let done_and_force_candidate_contact_rx = { + let mut done_and_force_candidate_contact_rx = + self.done_and_force_candidate_contact_rx.lock().await; + done_and_force_candidate_contact_rx.take() + }; + + if let Some((mut done_rx, mut force_candidate_contact_rx)) = + done_and_force_candidate_contact_rx + { + let ai = Arc::clone(self); + tokio::spawn(async move { + loop { + let mut interval = 
DEFAULT_CHECK_INTERVAL; + + let mut update_interval = |x: Duration| { + if x != ZERO_DURATION && (interval == ZERO_DURATION || interval > x) { + interval = x; + } + }; + + match last_connection_state { + ConnectionState::New | ConnectionState::Checking => { + // While connecting, check candidates more frequently + update_interval(check_interval); + } + ConnectionState::Connected | ConnectionState::Disconnected => { + update_interval(keepalive_interval); + } + _ => {} + }; + // Ensure we run our task loop as quickly as the minimum of our various configured timeouts + update_interval(disconnected_timeout); + update_interval(failed_timeout); + + let t = tokio::time::sleep(interval); + tokio::pin!(t); + + tokio::select! { + _ = t.as_mut() => { + ai.contact(&mut last_connection_state, &mut checking_duration).await; + }, + _ = force_candidate_contact_rx.recv() => { + ai.contact(&mut last_connection_state, &mut checking_duration).await; + }, + _ = done_rx.recv() => { + return; + } + } + } + }); + } + } + + pub(crate) async fn update_connection_state(&self, new_state: ConnectionState) { + if self.connection_state.load(Ordering::SeqCst) != new_state as u8 { + // Connection has gone to failed, release all gathered candidates + if new_state == ConnectionState::Failed { + self.delete_all_candidates().await; + } + + log::info!( + "[{}]: Setting new connection state: {}", + self.get_name(), + new_state + ); + self.connection_state + .store(new_state as u8, Ordering::SeqCst); + + // Call handler after finishing current task since we may be holding the agent lock + // and the handler may also require it + { + let chan_state_tx = self.chan_state_tx.lock().await; + if let Some(tx) = &*chan_state_tx { + let _ = tx.send(new_state).await; + } + } + } + } + + pub(crate) async fn set_selected_pair(&self, p: Option>) { + log::trace!( + "[{}]: Set selected candidate pair: {:?}", + self.get_name(), + p + ); + + if let Some(p) = p { + p.nominated.store(true, Ordering::SeqCst); + self.agent_conn.selected_pair.store(Some(p)); + + self.update_connection_state(ConnectionState::Connected) + .await; + + // Notify when the selected pair changes + { + let chan_candidate_pair_tx = self.chan_candidate_pair_tx.lock().await; + if let Some(tx) = &*chan_candidate_pair_tx { + let _ = tx.send(()).await; + } + } + + // Signal connected + { + let mut on_connected_tx = self.on_connected_tx.lock().await; + on_connected_tx.take(); + } + } else { + self.agent_conn.selected_pair.store(None); + } + } + + pub(crate) async fn ping_all_candidates(&self) { + log::trace!("[{}]: pinging all candidates", self.get_name(),); + + let mut pairs: Vec<( + Arc, + Arc, + )> = vec![]; + + { + let mut checklist = self.agent_conn.checklist.lock().await; + if checklist.is_empty() { + log::warn!( + "[{}]: pingAllCandidates called with no candidate pairs. 
Connection is not possible yet.", + self.get_name(), + ); + } + for p in &mut *checklist { + let p_state = p.state.load(Ordering::SeqCst); + if p_state == CandidatePairState::Waiting as u8 { + p.state + .store(CandidatePairState::InProgress as u8, Ordering::SeqCst); + } else if p_state != CandidatePairState::InProgress as u8 { + continue; + } + + if p.binding_request_count.load(Ordering::SeqCst) > self.max_binding_requests { + log::trace!( + "[{}]: max requests reached for pair {}, marking it as failed", + self.get_name(), + p + ); + p.state + .store(CandidatePairState::Failed as u8, Ordering::SeqCst); + } else { + p.binding_request_count.fetch_add(1, Ordering::SeqCst); + let local = p.local.clone(); + let remote = p.remote.clone(); + pairs.push((local, remote)); + } + } + } + + for (local, remote) in pairs { + self.ping_candidate(&local, &remote).await; + } + } + + pub(crate) async fn add_pair( + &self, + local: Arc, + remote: Arc, + ) { + let p = Arc::new(CandidatePair::new( + local, + remote, + self.is_controlling.load(Ordering::SeqCst), + )); + let mut checklist = self.agent_conn.checklist.lock().await; + checklist.push(p); + } + + pub(crate) async fn find_pair( + &self, + local: &Arc, + remote: &Arc, + ) -> Option> { + let checklist = self.agent_conn.checklist.lock().await; + for p in &*checklist { + if p.local.equal(&**local) && p.remote.equal(&**remote) { + return Some(p.clone()); + } + } + None + } + + /// Checks if the selected pair is (still) valid. + /// Note: the caller should hold the agent lock. + pub(crate) async fn validate_selected_pair(&self) -> bool { + let (valid, disconnected_time) = { + let selected_pair = self.agent_conn.selected_pair.load(); + (*selected_pair).as_ref().map_or_else( + || (false, Duration::from_secs(0)), + |selected_pair| { + let disconnected_time = SystemTime::now() + .duration_since(selected_pair.remote.last_received()) + .unwrap_or_else(|_| Duration::from_secs(0)); + (true, disconnected_time) + }, + ) + }; + + if valid { + // Only allow transitions to failed if a.failedTimeout is non-zero + let mut total_time_to_failure = self.failed_timeout; + if total_time_to_failure != Duration::from_secs(0) { + total_time_to_failure += self.disconnected_timeout; + } + + if total_time_to_failure != Duration::from_secs(0) + && disconnected_time > total_time_to_failure + { + self.update_connection_state(ConnectionState::Failed).await; + } else if self.disconnected_timeout != Duration::from_secs(0) + && disconnected_time > self.disconnected_timeout + { + self.update_connection_state(ConnectionState::Disconnected) + .await; + } else { + self.update_connection_state(ConnectionState::Connected) + .await; + } + } + + valid + } + + /// Sends STUN Binding Indications to the selected pair. + /// if no packet has been sent on that pair in the last keepaliveInterval. + /// Note: the caller should hold the agent lock. 
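+ /// (In practice a Binding request, rather than an indication, is sent so that RFC 7675 consent freshness is refreshed as well.)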
+ pub(crate) async fn check_keepalive(&self) { + let (local, remote) = { + let selected_pair = self.agent_conn.selected_pair.load(); + (*selected_pair) + .as_ref() + .map_or((None, None), |selected_pair| { + ( + Some(selected_pair.local.clone()), + Some(selected_pair.remote.clone()), + ) + }) + }; + + if let (Some(local), Some(remote)) = (local, remote) { + let last_sent = SystemTime::now() + .duration_since(local.last_sent()) + .unwrap_or_else(|_| Duration::from_secs(0)); + + let last_received = SystemTime::now() + .duration_since(remote.last_received()) + .unwrap_or_else(|_| Duration::from_secs(0)); + + if (self.keepalive_interval != Duration::from_secs(0)) + && ((last_sent > self.keepalive_interval) + || (last_received > self.keepalive_interval)) + { + // we use binding request instead of indication to support refresh consent schemas + // see https://tools.ietf.org/html/rfc7675 + self.ping_candidate(&local, &remote).await; + } + } + } + + fn request_connectivity_check(&self) { + let _ = self.force_candidate_contact_tx.try_send(true); + } + + /// Assumes you are holding the lock (must be execute using a.run). + pub(crate) async fn add_remote_candidate(&self, c: &Arc) { + let network_type = c.network_type(); + + { + let mut remote_candidates = self.remote_candidates.lock().await; + if let Some(cands) = remote_candidates.get(&network_type) { + for cand in cands { + if cand.equal(&**c) { + return; + } + } + } + + if let Some(cands) = remote_candidates.get_mut(&network_type) { + cands.push(c.clone()); + } else { + remote_candidates.insert(network_type, vec![c.clone()]); + } + } + + let mut local_cands = vec![]; + { + let local_candidates = self.local_candidates.lock().await; + if let Some(cands) = local_candidates.get(&network_type) { + local_cands = cands.clone(); + } + } + + for cand in local_cands { + self.add_pair(cand, c.clone()).await; + } + + self.request_connectivity_check(); + } + + pub(crate) async fn add_candidate( + self: &Arc, + c: &Arc, + ) -> Result<()> { + let initialized_ch = { + let started_ch_tx = self.started_ch_tx.lock().await; + (*started_ch_tx).as_ref().map(|tx| tx.subscribe()) + }; + + self.start_candidate(c, initialized_ch).await; + + let network_type = c.network_type(); + { + let mut local_candidates = self.local_candidates.lock().await; + if let Some(cands) = local_candidates.get(&network_type) { + for cand in cands { + if cand.equal(&**c) { + if let Err(err) = c.close().await { + log::warn!( + "[{}]: Failed to close duplicate candidate: {}", + self.get_name(), + err + ); + } + //TODO: why return? 
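+ // The candidate duplicates an existing local candidate: the new one is closed above and the existing entry is kept, so gathering simply continues.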
+ return Ok(()); + } + } + } + + if let Some(cands) = local_candidates.get_mut(&network_type) { + cands.push(c.clone()); + } else { + local_candidates.insert(network_type, vec![c.clone()]); + } + } + + let mut remote_cands = vec![]; + { + let remote_candidates = self.remote_candidates.lock().await; + if let Some(cands) = remote_candidates.get(&network_type) { + remote_cands = cands.clone(); + } + } + + for cand in remote_cands { + self.add_pair(c.clone(), cand).await; + } + + self.request_connectivity_check(); + { + let chan_candidate_tx = self.chan_candidate_tx.lock().await; + if let Some(tx) = &*chan_candidate_tx { + let _ = tx.send(Some(c.clone())).await; + } + } + + Ok(()) + } + + pub(crate) async fn close(&self) -> Result<()> { + { + let mut done_tx = self.done_tx.lock().await; + if done_tx.is_none() { + return Err(Error::ErrClosed); + } + done_tx.take(); + }; + self.delete_all_candidates().await; + { + let mut started_ch_tx = self.started_ch_tx.lock().await; + started_ch_tx.take(); + } + + self.agent_conn.buffer.close().await; + + self.update_connection_state(ConnectionState::Closed).await; + + { + let mut chan_candidate_tx = self.chan_candidate_tx.lock().await; + chan_candidate_tx.take(); + } + { + let mut chan_candidate_pair_tx = self.chan_candidate_pair_tx.lock().await; + chan_candidate_pair_tx.take(); + } + { + let mut chan_state_tx = self.chan_state_tx.lock().await; + chan_state_tx.take(); + } + + self.agent_conn.done.store(true, Ordering::SeqCst); + + Ok(()) + } + + /// Remove all candidates. + /// This closes any listening sockets and removes both the local and remote candidate lists. + /// + /// This is used for restarts, failures and on close. + pub(crate) async fn delete_all_candidates(&self) { + { + let mut local_candidates = self.local_candidates.lock().await; + for cs in local_candidates.values_mut() { + for c in cs { + if let Err(err) = c.close().await { + log::warn!( + "[{}]: Failed to close candidate {}: {}", + self.get_name(), + c, + err + ); + } + } + } + local_candidates.clear(); + } + + { + let mut remote_candidates = self.remote_candidates.lock().await; + for cs in remote_candidates.values_mut() { + for c in cs { + if let Err(err) = c.close().await { + log::warn!( + "[{}]: Failed to close candidate {}: {}", + self.get_name(), + c, + err + ); + } + } + } + remote_candidates.clear(); + } + } + + pub(crate) async fn find_remote_candidate( + &self, + network_type: NetworkType, + addr: SocketAddr, + ) -> Option> { + let (ip, port) = (addr.ip(), addr.port()); + + let remote_candidates = self.remote_candidates.lock().await; + if let Some(cands) = remote_candidates.get(&network_type) { + for c in cands { + if c.address() == ip.to_string() && c.port() == port { + return Some(c.clone()); + } + } + } + None + } + + pub(crate) async fn send_binding_request( + &self, + m: &Message, + local: &Arc, + remote: &Arc, + ) { + log::trace!( + "[{}]: ping STUN from {} to {}", + self.get_name(), + local, + remote + ); + + self.invalidate_pending_binding_requests(Instant::now()) + .await; + { + let mut pending_binding_requests = self.pending_binding_requests.lock().await; + pending_binding_requests.push(BindingRequest { + timestamp: Instant::now(), + transaction_id: m.transaction_id, + destination: remote.addr(), + is_use_candidate: m.contains(ATTR_USE_CANDIDATE), + }); + } + + self.send_stun(m, local, remote).await; + } + + pub(crate) async fn send_binding_success( + &self, + m: &Message, + local: &Arc, + remote: &Arc, + ) { + let addr = remote.addr(); + let (ip, port) = (addr.ip(), 
addr.port()); + let local_pwd = { + let ufrag_pwd = self.ufrag_pwd.lock().await; + ufrag_pwd.local_pwd.clone() + }; + + let (out, result) = { + let mut out = Message::new(); + let result = out.build(&[ + Box::new(m.clone()), + Box::new(BINDING_SUCCESS), + Box::new(XorMappedAddress { ip, port }), + Box::new(MessageIntegrity::new_short_term_integrity(local_pwd)), + Box::new(FINGERPRINT), + ]); + (out, result) + }; + + if let Err(err) = result { + log::warn!( + "[{}]: Failed to handle inbound ICE from: {} to: {} error: {}", + self.get_name(), + local, + remote, + err + ); + } else { + self.send_stun(&out, local, remote).await; + } + } + + /// Removes pending binding requests that are over `maxBindingRequestTimeout` old Let HTO be the + /// transaction timeout, which SHOULD be 2*RTT if RTT is known or 500 ms otherwise. + /// + /// reference: (IETF ref-8445)[https://tools.ietf.org/html/rfc8445#appendix-B.1]. + pub(crate) async fn invalidate_pending_binding_requests(&self, filter_time: Instant) { + let mut pending_binding_requests = self.pending_binding_requests.lock().await; + let initial_size = pending_binding_requests.len(); + + let mut temp = vec![]; + for binding_request in pending_binding_requests.drain(..) { + if filter_time + .checked_duration_since(binding_request.timestamp) + .map(|duration| duration < MAX_BINDING_REQUEST_TIMEOUT) + .unwrap_or(true) + { + temp.push(binding_request); + } + } + + *pending_binding_requests = temp; + let bind_requests_removed = initial_size - pending_binding_requests.len(); + if bind_requests_removed > 0 { + log::trace!( + "[{}]: Discarded {} binding requests because they expired", + self.get_name(), + bind_requests_removed + ); + } + } + + /// Assert that the passed `TransactionID` is in our `pendingBindingRequests` and returns the + /// destination, If the bindingRequest was valid remove it from our pending cache. + pub(crate) async fn handle_inbound_binding_success( + &self, + id: TransactionId, + ) -> Option { + self.invalidate_pending_binding_requests(Instant::now()) + .await; + + let mut pending_binding_requests = self.pending_binding_requests.lock().await; + for i in 0..pending_binding_requests.len() { + if pending_binding_requests[i].transaction_id == id { + let valid_binding_request = pending_binding_requests.remove(i); + return Some(valid_binding_request); + } + } + None + } + + /// Processes STUN traffic from a remote candidate. 
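+ /// Only Binding requests, success responses and indications are handled; messages failing the role, username or integrity checks are logged and dropped.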
+ pub(crate) async fn handle_inbound( + &self, + m: &mut Message, + local: &Arc, + remote: SocketAddr, + ) { + if m.typ.method != METHOD_BINDING + || !(m.typ.class == CLASS_SUCCESS_RESPONSE + || m.typ.class == CLASS_REQUEST + || m.typ.class == CLASS_INDICATION) + { + log::trace!( + "[{}]: unhandled STUN from {} to {} class({}) method({})", + self.get_name(), + remote, + local, + m.typ.class, + m.typ.method + ); + return; + } + + if self.is_controlling.load(Ordering::SeqCst) { + if m.contains(ATTR_ICE_CONTROLLING) { + log::debug!( + "[{}]: inbound isControlling && a.isControlling == true", + self.get_name(), + ); + return; + } else if m.contains(ATTR_USE_CANDIDATE) { + log::debug!( + "[{}]: useCandidate && a.isControlling == true", + self.get_name(), + ); + return; + } + } else if m.contains(ATTR_ICE_CONTROLLED) { + log::debug!( + "[{}]: inbound isControlled && a.isControlling == false", + self.get_name(), + ); + return; + } + + let mut remote_candidate = self + .find_remote_candidate(local.network_type(), remote) + .await; + if m.typ.class == CLASS_SUCCESS_RESPONSE { + { + let ufrag_pwd = self.ufrag_pwd.lock().await; + if let Err(err) = + assert_inbound_message_integrity(m, ufrag_pwd.remote_pwd.as_bytes()) + { + log::warn!( + "[{}]: discard message from ({}), {}", + self.get_name(), + remote, + err + ); + return; + } + } + + if let Some(rc) = &remote_candidate { + self.handle_success_response(m, local, rc, remote).await; + } else { + log::warn!( + "[{}]: discard success message from ({}), no such remote", + self.get_name(), + remote + ); + return; + } + } else if m.typ.class == CLASS_REQUEST { + { + let ufrag_pwd = self.ufrag_pwd.lock().await; + let username = + ufrag_pwd.local_ufrag.clone() + ":" + ufrag_pwd.remote_ufrag.as_str(); + if let Err(err) = assert_inbound_username(m, &username) { + log::warn!( + "[{}]: discard message from ({}), {}", + self.get_name(), + remote, + err + ); + return; + } else if let Err(err) = + assert_inbound_message_integrity(m, ufrag_pwd.local_pwd.as_bytes()) + { + log::warn!( + "[{}]: discard message from ({}), {}", + self.get_name(), + remote, + err + ); + return; + } + } + + if remote_candidate.is_none() { + let (ip, port, network_type) = (remote.ip(), remote.port(), NetworkType::Udp4); + + let prflx_candidate_config = CandidatePeerReflexiveConfig { + base_config: CandidateBaseConfig { + network: network_type.to_string(), + address: ip.to_string(), + port, + component: local.component(), + ..CandidateBaseConfig::default() + }, + rel_addr: "".to_owned(), + rel_port: 0, + }; + + match prflx_candidate_config.new_candidate_peer_reflexive() { + Ok(prflx_candidate) => remote_candidate = Some(Arc::new(prflx_candidate)), + Err(err) => { + log::error!( + "[{}]: Failed to create new remote prflx candidate ({})", + self.get_name(), + err + ); + return; + } + }; + + log::debug!( + "[{}]: adding a new peer-reflexive candidate: {} ", + self.get_name(), + remote + ); + if let Some(rc) = &remote_candidate { + self.add_remote_candidate(rc).await; + } + } + + log::trace!( + "[{}]: inbound STUN (Request) from {} to {}", + self.get_name(), + remote, + local + ); + + if let Some(rc) = &remote_candidate { + self.handle_binding_request(m, local, rc).await; + } + } + + if let Some(rc) = remote_candidate { + rc.seen(false); + } + } + + /// Processes non STUN traffic from a remote candidate, and returns true if it is an actual + /// remote candidate. 
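+ /// A matching candidate is also marked as seen, refreshing its receive timestamp.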
+ pub(crate) async fn validate_non_stun_traffic( + &self, + local: &Arc, + remote: SocketAddr, + ) -> bool { + self.find_remote_candidate(local.network_type(), remote) + .await + .map_or(false, |remote_candidate| { + remote_candidate.seen(false); + true + }) + } + + /// Sets the credentials of the remote agent. + pub(crate) async fn set_remote_credentials( + &self, + remote_ufrag: String, + remote_pwd: String, + ) -> Result<()> { + if remote_ufrag.is_empty() { + return Err(Error::ErrRemoteUfragEmpty); + } else if remote_pwd.is_empty() { + return Err(Error::ErrRemotePwdEmpty); + } + + let mut ufrag_pwd = self.ufrag_pwd.lock().await; + ufrag_pwd.remote_ufrag = remote_ufrag; + ufrag_pwd.remote_pwd = remote_pwd; + Ok(()) + } + + pub(crate) async fn send_stun( + &self, + msg: &Message, + local: &Arc, + remote: &Arc, + ) { + if let Err(err) = local.write_to(&msg.raw, &**remote).await { + log::trace!( + "[{}]: failed to send STUN message: {}", + self.get_name(), + err + ); + } + } + + /// Runs the candidate using the provided connection. + async fn start_candidate( + self: &Arc, + candidate: &Arc, + initialized_ch: Option>, + ) { + let (closed_ch_tx, closed_ch_rx) = broadcast::channel(1); + { + let closed_ch = candidate.get_closed_ch(); + let mut closed = closed_ch.lock().await; + *closed = Some(closed_ch_tx); + } + + let cand = Arc::clone(candidate); + if let Some(conn) = candidate.get_conn() { + let conn = Arc::clone(conn); + let addr = candidate.addr(); + let ai = Arc::clone(self); + tokio::spawn(async move { + let _ = ai + .recv_loop(cand, closed_ch_rx, initialized_ch, conn, addr) + .await; + }); + } else { + log::error!("[{}]: Can't start due to conn is_none", self.get_name(),); + } + } + + pub(super) fn start_on_connection_state_change_routine( + self: &Arc, + mut chan_state_rx: mpsc::Receiver, + mut chan_candidate_rx: mpsc::Receiver>>, + mut chan_candidate_pair_rx: mpsc::Receiver<()>, + ) { + let ai = Arc::clone(self); + tokio::spawn(async move { + // CandidatePair and ConnectionState are usually changed at once. + // Blocking one by the other one causes deadlock. + while chan_candidate_pair_rx.recv().await.is_some() { + if let (Some(cb), Some(p)) = ( + &*ai.on_selected_candidate_pair_change_hdlr.load(), + &*ai.agent_conn.selected_pair.load(), + ) { + let mut f = cb.lock().await; + f(&p.local, &p.remote).await; + } + } + }); + + let ai = Arc::clone(self); + tokio::spawn(async move { + loop { + tokio::select! { + opt_state = chan_state_rx.recv() => { + if let Some(s) = opt_state { + if let Some(handler) = &*ai.on_connection_state_change_hdlr.load() { + let mut f = handler.lock().await; + f(s).await; + } + } else { + while let Some(c) = chan_candidate_rx.recv().await { + if let Some(handler) = &*ai.on_candidate_hdlr.load() { + let mut f = handler.lock().await; + f(c).await; + } + } + break; + } + }, + opt_cand = chan_candidate_rx.recv() => { + if let Some(c) = opt_cand { + if let Some(handler) = &*ai.on_candidate_hdlr.load() { + let mut f = handler.lock().await; + f(c).await; + } + } else { + while let Some(s) = chan_state_rx.recv().await { + if let Some(handler) = &*ai.on_connection_state_change_hdlr.load() { + let mut f = handler.lock().await; + f(s).await; + } + } + break; + } + } + } + } + }); + } + + async fn recv_loop( + self: &Arc, + candidate: Arc, + mut closed_ch_rx: broadcast::Receiver<()>, + initialized_ch: Option>, + conn: Arc, + addr: SocketAddr, + ) -> Result<()> { + if let Some(mut initialized_ch) = initialized_ch { + tokio::select! 
{ + _ = initialized_ch.recv() => {} + _ = closed_ch_rx.recv() => return Err(Error::ErrClosed), + } + } + + let mut buffer = vec![0_u8; RECEIVE_MTU]; + let mut n; + let mut src_addr; + loop { + tokio::select! { + result = conn.recv_from(&mut buffer) => { + match result { + Ok((num, src)) => { + n = num; + src_addr = src; + } + Err(err) => return Err(Error::Other(err.to_string())), + } + }, + _ = closed_ch_rx.recv() => return Err(Error::ErrClosed), + } + + self.handle_inbound_candidate_msg(&candidate, &buffer[..n], src_addr, addr) + .await; + } + } + + async fn handle_inbound_candidate_msg( + self: &Arc, + c: &Arc, + buf: &[u8], + src_addr: SocketAddr, + addr: SocketAddr, + ) { + if stun::message::is_message(buf) { + let mut m = Message { + raw: vec![], + ..Message::default() + }; + // Explicitly copy raw buffer so Message can own the memory. + m.raw.extend_from_slice(buf); + + if let Err(err) = m.decode() { + log::warn!( + "[{}]: Failed to handle decode ICE from {} to {}: {}", + self.get_name(), + addr, + src_addr, + err + ); + } else { + self.handle_inbound(&mut m, c, src_addr).await; + } + } else if !self.validate_non_stun_traffic(c, src_addr).await { + log::warn!( + "[{}]: Discarded message, not a valid remote candidate", + self.get_name(), + //c.addr().await //from {} + ); + } else if let Err(err) = self.agent_conn.buffer.write(buf).await { + // NOTE This will return packetio.ErrFull if the buffer ever manages to fill up. + log::warn!("[{}]: failed to write packet: {}", self.get_name(), err); + } + } + + pub(crate) fn get_name(&self) -> &str { + if self.is_controlling.load(Ordering::SeqCst) { + "controlling" + } else { + "controlled" + } + } +} diff --git a/reserved/ice/src/agent/agent_selector.rs b/reserved/ice/src/agent/agent_selector.rs new file mode 100644 index 0000000..b7e05fc --- /dev/null +++ b/reserved/ice/src/agent/agent_selector.rs @@ -0,0 +1,545 @@ +use std::net::SocketAddr; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +use async_trait::async_trait; +use stun::agent::*; +use stun::attributes::*; +use stun::fingerprint::*; +use stun::integrity::*; +use stun::message::*; +use stun::textattrs::*; +use tokio::time::{Duration, Instant}; + +use crate::agent::agent_internal::*; +use crate::candidate::*; +use crate::control::*; +use crate::priority::*; +use crate::use_candidate::*; + +#[async_trait] +trait ControllingSelector { + async fn start(&self); + async fn contact_candidates(&self); + async fn ping_candidate( + &self, + local: &Arc, + remote: &Arc, + ); + async fn handle_success_response( + &self, + m: &Message, + local: &Arc, + remote: &Arc, + remote_addr: SocketAddr, + ); + async fn handle_binding_request( + &self, + m: &Message, + local: &Arc, + remote: &Arc, + ); +} + +#[async_trait] +trait ControlledSelector { + async fn start(&self); + async fn contact_candidates(&self); + async fn ping_candidate( + &self, + local: &Arc, + remote: &Arc, + ); + async fn handle_success_response( + &self, + m: &Message, + local: &Arc, + remote: &Arc, + remote_addr: SocketAddr, + ); + async fn handle_binding_request( + &self, + m: &Message, + local: &Arc, + remote: &Arc, + ); +} + +impl AgentInternal { + fn is_nominatable(&self, c: &Arc) -> bool { + let start_time = *self.start_time.lock(); + match c.candidate_type() { + CandidateType::Host => { + Instant::now() + .checked_duration_since(start_time) + .unwrap_or_else(|| Duration::from_secs(0)) + .as_nanos() + > self.host_acceptance_min_wait.as_nanos() + } + CandidateType::ServerReflexive => { + Instant::now() + 
.checked_duration_since(start_time) + .unwrap_or_else(|| Duration::from_secs(0)) + .as_nanos() + > self.srflx_acceptance_min_wait.as_nanos() + } + CandidateType::PeerReflexive => { + Instant::now() + .checked_duration_since(start_time) + .unwrap_or_else(|| Duration::from_secs(0)) + .as_nanos() + > self.prflx_acceptance_min_wait.as_nanos() + } + CandidateType::Relay => { + Instant::now() + .checked_duration_since(start_time) + .unwrap_or_else(|| Duration::from_secs(0)) + .as_nanos() + > self.relay_acceptance_min_wait.as_nanos() + } + CandidateType::Unspecified => { + log::error!( + "is_nominatable invalid candidate type {}", + c.candidate_type() + ); + false + } + } + } + + async fn nominate_pair(&self) { + let result = { + let nominated_pair = self.nominated_pair.lock().await; + if let Some(pair) = &*nominated_pair { + // The controlling agent MUST include the USE-CANDIDATE attribute in + // order to nominate a candidate pair (Section 8.1.1). The controlled + // agent MUST NOT include the USE-CANDIDATE attribute in a Binding + // request. + + let (msg, result) = { + let ufrag_pwd = self.ufrag_pwd.lock().await; + let username = + ufrag_pwd.remote_ufrag.clone() + ":" + ufrag_pwd.local_ufrag.as_str(); + let mut msg = Message::new(); + let result = msg.build(&[ + Box::new(BINDING_REQUEST), + Box::new(TransactionId::new()), + Box::new(Username::new(ATTR_USERNAME, username)), + Box::::default(), + Box::new(AttrControlling(self.tie_breaker.load(Ordering::SeqCst))), + Box::new(PriorityAttr(pair.local.priority())), + Box::new(MessageIntegrity::new_short_term_integrity( + ufrag_pwd.remote_pwd.clone(), + )), + Box::new(FINGERPRINT), + ]); + (msg, result) + }; + + if let Err(err) = result { + log::error!("{}", err); + None + } else { + log::trace!( + "ping STUN (nominate candidate pair from {} to {}", + pair.local, + pair.remote + ); + let local = pair.local.clone(); + let remote = pair.remote.clone(); + Some((msg, local, remote)) + } + } else { + None + } + }; + + if let Some((msg, local, remote)) = result { + self.send_binding_request(&msg, &local, &remote).await; + } + } + + pub(crate) async fn start(&self) { + if self.is_controlling.load(Ordering::SeqCst) { + ControllingSelector::start(self).await; + } else { + ControlledSelector::start(self).await; + } + } + + pub(crate) async fn contact_candidates(&self) { + if self.is_controlling.load(Ordering::SeqCst) { + ControllingSelector::contact_candidates(self).await; + } else { + ControlledSelector::contact_candidates(self).await; + } + } + + pub(crate) async fn ping_candidate( + &self, + local: &Arc, + remote: &Arc, + ) { + if self.is_controlling.load(Ordering::SeqCst) { + ControllingSelector::ping_candidate(self, local, remote).await; + } else { + ControlledSelector::ping_candidate(self, local, remote).await; + } + } + + pub(crate) async fn handle_success_response( + &self, + m: &Message, + local: &Arc, + remote: &Arc, + remote_addr: SocketAddr, + ) { + if self.is_controlling.load(Ordering::SeqCst) { + ControllingSelector::handle_success_response(self, m, local, remote, remote_addr).await; + } else { + ControlledSelector::handle_success_response(self, m, local, remote, remote_addr).await; + } + } + + pub(crate) async fn handle_binding_request( + &self, + m: &Message, + local: &Arc, + remote: &Arc, + ) { + if self.is_controlling.load(Ordering::SeqCst) { + ControllingSelector::handle_binding_request(self, m, local, remote).await; + } else { + ControlledSelector::handle_binding_request(self, m, local, remote).await; + } + } +} + +#[async_trait] +impl 
ControllingSelector for AgentInternal { + async fn start(&self) { + { + let mut nominated_pair = self.nominated_pair.lock().await; + *nominated_pair = None; + } + *self.start_time.lock() = Instant::now(); + } + + async fn contact_candidates(&self) { + // A lite selector should not contact candidates + if self.lite.load(Ordering::SeqCst) { + // This only happens if both peers are lite. See RFC 8445 S6.1.1 and S6.2 + log::trace!("now falling back to full agent"); + } + + let nominated_pair_is_some = { + let nominated_pair = self.nominated_pair.lock().await; + nominated_pair.is_some() + }; + + if self.agent_conn.get_selected_pair().is_some() { + if self.validate_selected_pair().await { + log::trace!("[{}]: checking keepalive", self.get_name()); + self.check_keepalive().await; + } + } else if nominated_pair_is_some { + self.nominate_pair().await; + } else { + let has_nominated_pair = + if let Some(p) = self.agent_conn.get_best_valid_candidate_pair().await { + self.is_nominatable(&p.local) && self.is_nominatable(&p.remote) + } else { + false + }; + + if has_nominated_pair { + if let Some(p) = self.agent_conn.get_best_valid_candidate_pair().await { + log::trace!( + "Nominatable pair found, nominating ({}, {})", + p.local.to_string(), + p.remote.to_string() + ); + p.nominated.store(true, Ordering::SeqCst); + { + let mut nominated_pair = self.nominated_pair.lock().await; + *nominated_pair = Some(p); + } + } + + self.nominate_pair().await; + } else { + self.ping_all_candidates().await; + } + } + } + + async fn ping_candidate( + &self, + local: &Arc, + remote: &Arc, + ) { + let (msg, result) = { + let ufrag_pwd = self.ufrag_pwd.lock().await; + let username = ufrag_pwd.remote_ufrag.clone() + ":" + ufrag_pwd.local_ufrag.as_str(); + let mut msg = Message::new(); + let result = msg.build(&[ + Box::new(BINDING_REQUEST), + Box::new(TransactionId::new()), + Box::new(Username::new(ATTR_USERNAME, username)), + Box::new(AttrControlling(self.tie_breaker.load(Ordering::SeqCst))), + Box::new(PriorityAttr(local.priority())), + Box::new(MessageIntegrity::new_short_term_integrity( + ufrag_pwd.remote_pwd.clone(), + )), + Box::new(FINGERPRINT), + ]); + (msg, result) + }; + + if let Err(err) = result { + log::error!("{}", err); + } else { + self.send_binding_request(&msg, local, remote).await; + } + } + + async fn handle_success_response( + &self, + m: &Message, + local: &Arc, + remote: &Arc, + remote_addr: SocketAddr, + ) { + if let Some(pending_request) = self.handle_inbound_binding_success(m.transaction_id).await { + let transaction_addr = pending_request.destination; + + // Assert that NAT is not symmetric + // https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1 + if transaction_addr != remote_addr { + log::debug!("discard message: transaction source and destination does not match expected({}), actual({})", transaction_addr, remote); + return; + } + + log::trace!( + "inbound STUN (SuccessResponse) from {} to {}", + remote, + local + ); + let selected_pair_is_none = self.agent_conn.get_selected_pair().is_none(); + + if let Some(p) = self.find_pair(local, remote).await { + p.state + .store(CandidatePairState::Succeeded as u8, Ordering::SeqCst); + log::trace!( + "Found valid candidate pair: {}, p.state: {}, isUseCandidate: {}, {}", + p, + p.state.load(Ordering::SeqCst), + pending_request.is_use_candidate, + selected_pair_is_none + ); + if pending_request.is_use_candidate && selected_pair_is_none { + self.set_selected_pair(Some(Arc::clone(&p))).await; + } + } else { + // This shouldn't happen + log::error!("Success 
response from invalid candidate pair"); + } + } else { + log::warn!( + "discard message from ({}), unknown TransactionID 0x{:?}", + remote, + m.transaction_id + ); + } + } + + async fn handle_binding_request( + &self, + m: &Message, + local: &Arc, + remote: &Arc, + ) { + self.send_binding_success(m, local, remote).await; + log::trace!("controllingSelector: sendBindingSuccess"); + + if let Some(p) = self.find_pair(local, remote).await { + let nominated_pair_is_none = { + let nominated_pair = self.nominated_pair.lock().await; + nominated_pair.is_none() + }; + + log::trace!( + "controllingSelector: after findPair {}, p.state: {}, {}", + p, + p.state.load(Ordering::SeqCst), + nominated_pair_is_none, + //self.agent_conn.get_selected_pair().await.is_none() //, {} + ); + if p.state.load(Ordering::SeqCst) == CandidatePairState::Succeeded as u8 + && nominated_pair_is_none + && self.agent_conn.get_selected_pair().is_none() + { + if let Some(best_pair) = self.agent_conn.get_best_available_candidate_pair().await { + log::trace!( + "controllingSelector: getBestAvailableCandidatePair {}", + best_pair + ); + if best_pair == p + && self.is_nominatable(&p.local) + && self.is_nominatable(&p.remote) + { + log::trace!("The candidate ({}, {}) is the best candidate available, marking it as nominated", + p.local, p.remote); + { + let mut nominated_pair = self.nominated_pair.lock().await; + *nominated_pair = Some(p); + } + self.nominate_pair().await; + } + } else { + log::trace!("No best pair available"); + } + } + } else { + log::trace!("controllingSelector: addPair"); + self.add_pair(local.clone(), remote.clone()).await; + } + } +} + +#[async_trait] +impl ControlledSelector for AgentInternal { + async fn start(&self) {} + + async fn contact_candidates(&self) { + // A lite selector should not contact candidates + if self.lite.load(Ordering::SeqCst) { + self.validate_selected_pair().await; + } else if self.agent_conn.get_selected_pair().is_some() { + if self.validate_selected_pair().await { + log::trace!("[{}]: checking keepalive", self.get_name()); + self.check_keepalive().await; + } + } else { + self.ping_all_candidates().await; + } + } + + async fn ping_candidate( + &self, + local: &Arc, + remote: &Arc, + ) { + let (msg, result) = { + let ufrag_pwd = self.ufrag_pwd.lock().await; + let username = ufrag_pwd.remote_ufrag.clone() + ":" + ufrag_pwd.local_ufrag.as_str(); + let mut msg = Message::new(); + let result = msg.build(&[ + Box::new(BINDING_REQUEST), + Box::new(TransactionId::new()), + Box::new(Username::new(ATTR_USERNAME, username)), + Box::new(AttrControlled(self.tie_breaker.load(Ordering::SeqCst))), + Box::new(PriorityAttr(local.priority())), + Box::new(MessageIntegrity::new_short_term_integrity( + ufrag_pwd.remote_pwd.clone(), + )), + Box::new(FINGERPRINT), + ]); + (msg, result) + }; + + if let Err(err) = result { + log::error!("{}", err); + } else { + self.send_binding_request(&msg, local, remote).await; + } + } + + async fn handle_success_response( + &self, + m: &Message, + local: &Arc, + remote: &Arc, + remote_addr: SocketAddr, + ) { + // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 + // If the controlled agent does not accept the request from the + // controlling agent, the controlled agent MUST reject the nomination + // request with an appropriate error code response (e.g., 400) + // [RFC5389]. 
+ + if let Some(pending_request) = self.handle_inbound_binding_success(m.transaction_id).await { + let transaction_addr = pending_request.destination; + + // Assert that NAT is not symmetric + // https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1 + if transaction_addr != remote_addr { + log::debug!("discard message: transaction source and destination does not match expected({}), actual({})", transaction_addr, remote); + return; + } + + log::trace!( + "inbound STUN (SuccessResponse) from {} to {}", + remote, + local + ); + + if let Some(p) = self.find_pair(local, remote).await { + p.state + .store(CandidatePairState::Succeeded as u8, Ordering::SeqCst); + log::trace!("Found valid candidate pair: {}", p); + } else { + // This shouldn't happen + log::error!("Success response from invalid candidate pair"); + } + } else { + log::warn!( + "discard message from ({}), unknown TransactionID 0x{:?}", + remote, + m.transaction_id + ); + } + } + + async fn handle_binding_request( + &self, + m: &Message, + local: &Arc, + remote: &Arc, + ) { + if self.find_pair(local, remote).await.is_none() { + self.add_pair(local.clone(), remote.clone()).await; + } + + if let Some(p) = self.find_pair(local, remote).await { + let use_candidate = m.contains(ATTR_USE_CANDIDATE); + if use_candidate { + // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 + + if p.state.load(Ordering::SeqCst) == CandidatePairState::Succeeded as u8 { + // If the state of this pair is Succeeded, it means that the check + // previously sent by this pair produced a successful response and + // generated a valid pair (Section 7.2.5.3.2). The agent sets the + // nominated flag value of the valid pair to true. + if self.agent_conn.get_selected_pair().is_none() { + self.set_selected_pair(Some(Arc::clone(&p))).await; + } + self.send_binding_success(m, local, remote).await; + } else { + // If the received Binding request triggered a new check to be + // enqueued in the triggered-check queue (Section 7.3.1.4), once the + // check is sent and if it generates a successful response, and + // generates a valid pair, the agent sets the nominated flag of the + // pair to true. If the request fails (Section 7.2.5.2), the agent + // MUST remove the candidate pair from the valid list, set the + // candidate pair state to Failed, and set the checklist state to + // Failed. + self.ping_candidate(local, remote).await; + } + } else { + self.send_binding_success(m, local, remote).await; + self.ping_candidate(local, remote).await; + } + } + } +} diff --git a/reserved/ice/src/agent/agent_stats.rs b/reserved/ice/src/agent/agent_stats.rs new file mode 100644 index 0000000..27ad3cc --- /dev/null +++ b/reserved/ice/src/agent/agent_stats.rs @@ -0,0 +1,283 @@ +use std::sync::atomic::Ordering; + +use tokio::time::Instant; + +use crate::agent::agent_internal::AgentInternal; +use crate::candidate::{CandidatePairState, CandidateType}; +use crate::network_type::NetworkType; + +/// Contains ICE candidate pair statistics. +pub struct CandidatePairStats { + /// The timestamp associated with this struct. + pub timestamp: Instant, + + /// The id of the local candidate. + pub local_candidate_id: String, + + /// The id of the remote candidate. + pub remote_candidate_id: String, + + /// The state of the checklist for the local and remote candidates in a pair. + pub state: CandidatePairState, + + /// It is true when this valid pair that should be used for media, + /// if it is the highest-priority one amongst those whose nominated flag is set. 
+ pub nominated: bool, + + /// The total number of packets sent on this candidate pair. + pub packets_sent: u32, + + /// The total number of packets received on this candidate pair. + pub packets_received: u32, + + /// The total number of payload bytes sent on this candidate pair not including headers or + /// padding. + pub bytes_sent: u64, + + /// The total number of payload bytes received on this candidate pair not including headers or + /// padding. + pub bytes_received: u64, + + /// The timestamp at which the last packet was sent on this particular candidate pair, excluding + /// STUN packets. + pub last_packet_sent_timestamp: Instant, + + /// The timestamp at which the last packet was received on this particular candidate pair, + /// excluding STUN packets. + pub last_packet_received_timestamp: Instant, + + /// The timestamp at which the first STUN request was sent on this particular candidate pair. + pub first_request_timestamp: Instant, + + /// The timestamp at which the last STUN request was sent on this particular candidate pair. + /// The average interval between two consecutive connectivity checks sent can be calculated with + /// (last_request_timestamp - first_request_timestamp) / requests_sent. + pub last_request_timestamp: Instant, + + /// Timestamp at which the last STUN response was received on this particular candidate pair. + pub last_response_timestamp: Instant, + + /// The sum of all round trip time measurements in seconds since the beginning of the session, + /// based on STUN connectivity check responses (responses_received), including those that reply + /// to requests that are sent in order to verify consent. The average round trip time can be + /// computed from total_round_trip_time by dividing it by responses_received. + pub total_round_trip_time: f64, + + /// The latest round trip time measured in seconds, computed from both STUN connectivity checks, + /// including those that are sent for consent verification. + pub current_round_trip_time: f64, + + /// It is calculated by the underlying congestion control by combining the available bitrate for + /// all the outgoing RTP streams using this candidate pair. The bitrate measurement does not + /// count the size of the IP or other transport layers like TCP or UDP. It is similar to the + /// TIAS defined in RFC 3890, i.e., it is measured in bits per second and the bitrate is + /// calculated over a 1 second window. + pub available_outgoing_bitrate: f64, + + /// It is calculated by the underlying congestion control by combining the available bitrate for + /// all the incoming RTP streams using this candidate pair. The bitrate measurement does not + /// count the size of the IP or other transport layers like TCP or UDP. It is similar to the + /// TIAS defined in RFC 3890, i.e., it is measured in bits per second and the bitrate is + /// calculated over a 1 second window. + pub available_incoming_bitrate: f64, + + /// The number of times the circuit breaker is triggered for this particular 5-tuple, + /// ceasing transmission. + pub circuit_breaker_trigger_count: u32, + + /// The total number of connectivity check requests received (including retransmissions). + /// It is impossible for the receiver to tell whether the request was sent in order to check + /// connectivity or check consent, so all connectivity checks requests are counted here. + pub requests_received: u64, + + /// The total number of connectivity check requests sent (not including retransmissions). 
+ pub requests_sent: u64, + + /// The total number of connectivity check responses received. + pub responses_received: u64, + + /// The total number of connectivity check responses sent. Since we cannot distinguish + /// connectivity check requests and consent requests, all responses are counted. + pub responses_sent: u64, + + /// The total number of connectivity check request retransmissions received. + pub retransmissions_received: u64, + + /// The total number of connectivity check request retransmissions sent. + pub retransmissions_sent: u64, + + /// The total number of consent requests sent. + pub consent_requests_sent: u64, + + /// The timestamp at which the latest valid STUN binding response expired. + pub consent_expired_timestamp: Instant, +} + +impl Default for CandidatePairStats { + fn default() -> Self { + Self { + timestamp: Instant::now(), + local_candidate_id: String::new(), + remote_candidate_id: String::new(), + state: CandidatePairState::default(), + nominated: false, + packets_sent: 0, + packets_received: 0, + bytes_sent: 0, + bytes_received: 0, + last_packet_sent_timestamp: Instant::now(), + last_packet_received_timestamp: Instant::now(), + first_request_timestamp: Instant::now(), + last_request_timestamp: Instant::now(), + last_response_timestamp: Instant::now(), + total_round_trip_time: 0.0, + current_round_trip_time: 0.0, + available_outgoing_bitrate: 0.0, + available_incoming_bitrate: 0.0, + circuit_breaker_trigger_count: 0, + requests_received: 0, + requests_sent: 0, + responses_received: 0, + responses_sent: 0, + retransmissions_received: 0, + retransmissions_sent: 0, + consent_requests_sent: 0, + consent_expired_timestamp: Instant::now(), + } + } +} + +/// Contains ICE candidate statistics related to the `ICETransport` objects. +#[derive(Debug, Clone)] +pub struct CandidateStats { + // The timestamp associated with this struct. + pub timestamp: Instant, + + /// The candidate id. + pub id: String, + + /// The type of network interface used by the base of a local candidate (the address the ICE + /// agent sends from). Only present for local candidates; it's not possible to know what type of + /// network interface a remote candidate is using. + /// + /// Note: This stat only tells you about the network interface used by the first "hop"; it's + /// possible that a connection will be bottlenecked by another type of network. For example, + /// when using Wi-Fi tethering, the networkType of the relevant candidate would be "wifi", even + /// when the next hop is over a cellular connection. + pub network_type: NetworkType, + + /// The IP address of the candidate, allowing for IPv4 addresses and IPv6 addresses, but fully + /// qualified domain names (FQDNs) are not allowed. + pub ip: String, + + /// The port number of the candidate. + pub port: u16, + + /// The `Type` field of the ICECandidate. + pub candidate_type: CandidateType, + + /// The `priority` field of the ICECandidate. + pub priority: u32, + + /// The url of the TURN or STUN server indicated in the that translated this IP address. + /// It is the url address surfaced in an PeerConnectionICEEvent. + pub url: String, + + /// The protocol used by the endpoint to communicate with the TURN server. This is only present + /// for local candidates. Valid values for the TURN url protocol is one of udp, tcp, or tls. + pub relay_protocol: String, + + /// It is true if the candidate has been deleted/freed. 
For host candidates, this means that any + /// network resources (typically a socket) associated with the candidate have been released. For + /// TURN candidates, this means the TURN allocation is no longer active. + /// + /// Only defined for local candidates. For remote candidates, this property is not applicable. + pub deleted: bool, +} + +impl Default for CandidateStats { + fn default() -> Self { + Self { + timestamp: Instant::now(), + id: String::new(), + network_type: NetworkType::default(), + ip: String::new(), + port: 0, + candidate_type: CandidateType::default(), + priority: 0, + url: String::new(), + relay_protocol: String::new(), + deleted: false, + } + } +} + +impl AgentInternal { + /// Returns a list of candidate pair stats. + pub(crate) async fn get_candidate_pairs_stats(&self) -> Vec { + let checklist = self.agent_conn.checklist.lock().await; + let mut res = Vec::with_capacity(checklist.len()); + for cp in &*checklist { + let stat = CandidatePairStats { + timestamp: Instant::now(), + local_candidate_id: cp.local.id(), + remote_candidate_id: cp.remote.id(), + state: cp.state.load(Ordering::SeqCst).into(), + nominated: cp.nominated.load(Ordering::SeqCst), + ..CandidatePairStats::default() + }; + res.push(stat); + } + res + } + + /// Returns a list of local candidates stats. + pub(crate) async fn get_local_candidates_stats(&self) -> Vec { + let local_candidates = self.local_candidates.lock().await; + let mut res = Vec::with_capacity(local_candidates.len()); + for (network_type, local_candidates) in &*local_candidates { + for c in local_candidates { + let stat = CandidateStats { + timestamp: Instant::now(), + id: c.id(), + network_type: *network_type, + ip: c.address(), + port: c.port(), + candidate_type: c.candidate_type(), + priority: c.priority(), + // URL string + relay_protocol: "udp".to_owned(), + // Deleted bool + ..CandidateStats::default() + }; + res.push(stat); + } + } + res + } + + /// Returns a list of remote candidates stats. 
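+    ///
+    /// As with the local variant above, only the fields the agent currently
+    /// tracks (id, network type, address, port, candidate type and priority)
+    /// are populated; `relay_protocol` is reported as "udp" and the remaining
+    /// fields are left at their defaults.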
+ pub(crate) async fn get_remote_candidates_stats(&self) -> Vec { + let remote_candidates = self.remote_candidates.lock().await; + let mut res = Vec::with_capacity(remote_candidates.len()); + for (network_type, remote_candidates) in &*remote_candidates { + for c in remote_candidates { + let stat = CandidateStats { + timestamp: Instant::now(), + id: c.id(), + network_type: *network_type, + ip: c.address(), + port: c.port(), + candidate_type: c.candidate_type(), + priority: c.priority(), + // URL string + relay_protocol: "udp".to_owned(), + // Deleted bool + ..CandidateStats::default() + }; + res.push(stat); + } + } + res + } +} diff --git a/reserved/ice/src/agent/agent_test.rs b/reserved/ice/src/agent/agent_test.rs new file mode 100644 index 0000000..3559888 --- /dev/null +++ b/reserved/ice/src/agent/agent_test.rs @@ -0,0 +1,2199 @@ +use std::net::Ipv4Addr; +use std::ops::Sub; +use std::str::FromStr; + +use async_trait::async_trait; +use stun::message::*; +use stun::textattrs::Username; +use util::vnet::*; +use util::Conn; +use waitgroup::{WaitGroup, Worker}; + +use super::agent_vnet_test::*; +use super::*; +use crate::agent::agent_transport_test::pipe; +use crate::candidate::candidate_base::*; +use crate::candidate::candidate_host::*; +use crate::candidate::candidate_peer_reflexive::*; +use crate::candidate::candidate_relay::*; +use crate::candidate::candidate_server_reflexive::*; +use crate::control::AttrControlling; +use crate::priority::PriorityAttr; +use crate::use_candidate::UseCandidateAttr; + +#[tokio::test] +async fn test_pair_search() -> Result<()> { + let config = AgentConfig::default(); + let a = Agent::new(config).await?; + + { + { + let checklist = a.internal.agent_conn.checklist.lock().await; + assert!( + checklist.is_empty(), + "TestPairSearch is only a valid test if a.validPairs is empty on construction" + ); + } + + let cp = a + .internal + .agent_conn + .get_best_available_candidate_pair() + .await; + assert!(cp.is_none(), "No Candidate pairs should exist"); + } + + a.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_pair_priority() -> Result<()> { + let a = Agent::new(AgentConfig::default()).await?; + + let host_config = CandidateHostConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "192.168.1.1".to_owned(), + port: 19216, + component: 1, + ..Default::default() + }, + ..Default::default() + }; + let host_local: Arc = Arc::new(host_config.new_candidate_host()?); + + let relay_config = CandidateRelayConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "1.2.3.4".to_owned(), + port: 12340, + component: 1, + ..Default::default() + }, + rel_addr: "4.3.2.1".to_owned(), + rel_port: 43210, + ..Default::default() + }; + + let relay_remote = relay_config.new_candidate_relay()?; + + let srflx_config = CandidateServerReflexiveConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "10.10.10.2".to_owned(), + port: 19218, + component: 1, + ..Default::default() + }, + rel_addr: "4.3.2.1".to_owned(), + rel_port: 43212, + }; + + let srflx_remote = srflx_config.new_candidate_server_reflexive()?; + + let prflx_config = CandidatePeerReflexiveConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "10.10.10.2".to_owned(), + port: 19217, + component: 1, + ..Default::default() + }, + rel_addr: "4.3.2.1".to_owned(), + rel_port: 43211, + }; + + let prflx_remote = prflx_config.new_candidate_peer_reflexive()?; + + let host_config = CandidateHostConfig { + 
base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "1.2.3.5".to_owned(), + port: 12350, + component: 1, + ..Default::default() + }, + ..Default::default() + }; + let host_remote = host_config.new_candidate_host()?; + + let remotes: Vec> = vec![ + Arc::new(relay_remote), + Arc::new(srflx_remote), + Arc::new(prflx_remote), + Arc::new(host_remote), + ]; + + { + for remote in remotes { + if a.internal.find_pair(&host_local, &remote).await.is_none() { + a.internal + .add_pair(host_local.clone(), remote.clone()) + .await; + } + + if let Some(p) = a.internal.find_pair(&host_local, &remote).await { + p.state + .store(CandidatePairState::Succeeded as u8, Ordering::SeqCst); + } + + if let Some(best_pair) = a + .internal + .agent_conn + .get_best_available_candidate_pair() + .await + { + assert_eq!( + best_pair.to_string(), + CandidatePair { + remote: remote.clone(), + local: host_local.clone(), + ..Default::default() + } + .to_string(), + "Unexpected bestPair {best_pair} (expected remote: {remote})", + ); + } else { + panic!("expected Some, but got None"); + } + } + } + + a.close().await?; + Ok(()) +} + +#[tokio::test] +async fn test_agent_get_stats() -> Result<()> { + let (conn_a, conn_b, agent_a, agent_b) = pipe(None, None).await?; + assert_eq!(agent_a.get_bytes_received(), 0); + assert_eq!(agent_a.get_bytes_sent(), 0); + assert_eq!(agent_b.get_bytes_received(), 0); + assert_eq!(agent_b.get_bytes_sent(), 0); + + let _na = conn_a.send(&[0u8; 10]).await?; + let mut buf = vec![0u8; 10]; + let _nb = conn_b.recv(&mut buf).await?; + + assert_eq!(agent_a.get_bytes_received(), 0); + assert_eq!(agent_a.get_bytes_sent(), 10); + + assert_eq!(agent_b.get_bytes_received(), 10); + assert_eq!(agent_b.get_bytes_sent(), 0); + + Ok(()) +} + +#[tokio::test] +async fn test_on_selected_candidate_pair_change() -> Result<()> { + let a = Agent::new(AgentConfig::default()).await?; + let (callback_called_tx, mut callback_called_rx) = mpsc::channel::<()>(1); + let callback_called_tx = Arc::new(Mutex::new(Some(callback_called_tx))); + let cb: OnSelectedCandidatePairChangeHdlrFn = Box::new(move |_, _| { + let callback_called_tx_clone = Arc::clone(&callback_called_tx); + Box::pin(async move { + let mut tx = callback_called_tx_clone.lock().await; + tx.take(); + }) + }); + a.on_selected_candidate_pair_change(cb); + + let host_config = CandidateHostConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "192.168.1.1".to_owned(), + port: 19216, + component: 1, + ..Default::default() + }, + ..Default::default() + }; + let host_local = host_config.new_candidate_host()?; + + let relay_config = CandidateRelayConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "1.2.3.4".to_owned(), + port: 12340, + component: 1, + ..Default::default() + }, + rel_addr: "4.3.2.1".to_owned(), + rel_port: 43210, + ..Default::default() + }; + let relay_remote = relay_config.new_candidate_relay()?; + + // select the pair + let p = Arc::new(CandidatePair::new( + Arc::new(host_local), + Arc::new(relay_remote), + false, + )); + a.internal.set_selected_pair(Some(p)).await; + + // ensure that the callback fired on setting the pair + let _ = callback_called_rx.recv().await; + + a.close().await?; + Ok(()) +} + +#[tokio::test] +async fn test_handle_peer_reflexive_udp_pflx_candidate() -> Result<()> { + let a = Agent::new(AgentConfig::default()).await?; + + let host_config = CandidateHostConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: 
"192.168.0.2".to_owned(), + port: 777, + component: 1, + conn: Some(Arc::new(MockConn {})), + ..Default::default() + }, + ..Default::default() + }; + + let local: Arc = Arc::new(host_config.new_candidate_host()?); + let remote = SocketAddr::from_str("172.17.0.3:999")?; + + let (username, local_pwd, tie_breaker) = { + let ufrag_pwd = a.internal.ufrag_pwd.lock().await; + ( + ufrag_pwd.local_ufrag.to_owned() + ":" + ufrag_pwd.remote_ufrag.as_str(), + ufrag_pwd.local_pwd.clone(), + a.internal.tie_breaker.load(Ordering::SeqCst), + ) + }; + + let mut msg = Message::new(); + msg.build(&[ + Box::new(BINDING_REQUEST), + Box::new(TransactionId::new()), + Box::new(Username::new(ATTR_USERNAME, username)), + Box::new(UseCandidateAttr::new()), + Box::new(AttrControlling(tie_breaker)), + Box::new(PriorityAttr(local.priority())), + Box::new(MessageIntegrity::new_short_term_integrity(local_pwd)), + Box::new(FINGERPRINT), + ])?; + + { + a.internal.handle_inbound(&mut msg, &local, remote).await; + + let remote_candidates = a.internal.remote_candidates.lock().await; + // length of remote candidate list must be one now + assert_eq!( + remote_candidates.len(), + 1, + "failed to add a network type to the remote candidate list" + ); + + // length of remote candidate list for a network type must be 1 + if let Some(cands) = remote_candidates.get(&local.network_type()) { + assert_eq!( + cands.len(), + 1, + "failed to add prflx candidate to remote candidate list" + ); + + let c = &cands[0]; + + assert_eq!( + c.candidate_type(), + CandidateType::PeerReflexive, + "candidate type must be prflx" + ); + + assert_eq!(c.address(), "172.17.0.3", "IP address mismatch"); + + assert_eq!(c.port(), 999, "Port number mismatch"); + } else { + panic!( + "expected non-empty remote candidate for network type {}", + local.network_type() + ); + } + } + + a.close().await?; + Ok(()) +} + +#[tokio::test] +async fn test_handle_peer_reflexive_unknown_remote() -> Result<()> { + let a = Agent::new(AgentConfig::default()).await?; + + let mut tid = TransactionId::default(); + tid.0[..3].copy_from_slice("ABC".as_bytes()); + + let remote_pwd = { + { + let mut pending_binding_requests = a.internal.pending_binding_requests.lock().await; + *pending_binding_requests = vec![BindingRequest { + timestamp: Instant::now(), + transaction_id: tid, + destination: SocketAddr::from_str("0.0.0.0:0")?, + is_use_candidate: false, + }]; + } + let ufrag_pwd = a.internal.ufrag_pwd.lock().await; + ufrag_pwd.remote_pwd.clone() + }; + + let host_config = CandidateHostConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "192.168.0.2".to_owned(), + port: 777, + component: 1, + conn: Some(Arc::new(MockConn {})), + ..Default::default() + }, + ..Default::default() + }; + + let local: Arc = Arc::new(host_config.new_candidate_host()?); + let remote = SocketAddr::from_str("172.17.0.3:999")?; + + let mut msg = Message::new(); + msg.build(&[ + Box::new(BINDING_SUCCESS), + Box::new(tid), + Box::new(MessageIntegrity::new_short_term_integrity(remote_pwd)), + Box::new(FINGERPRINT), + ])?; + + { + a.internal.handle_inbound(&mut msg, &local, remote).await; + + let remote_candidates = a.internal.remote_candidates.lock().await; + assert_eq!( + remote_candidates.len(), + 0, + "unknown remote was able to create a candidate" + ); + } + + a.close().await?; + Ok(()) +} + +//use std::io::Write; + +// Assert that Agent on startup sends message, and doesn't wait for connectivityTicker to fire +#[tokio::test] +async fn test_connectivity_on_startup() -> Result<()> 
{ + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init();*/ + + // Create a network with two interfaces + let wan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + cidr: "0.0.0.0/0".to_owned(), + ..Default::default() + })?)); + + let net0 = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec!["192.168.0.1".to_owned()], + ..Default::default() + }))); + let net1 = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec!["192.168.0.2".to_owned()], + ..Default::default() + }))); + + connect_net2router(&net0, &wan).await?; + connect_net2router(&net1, &wan).await?; + start_router(&wan).await?; + + let (a_notifier, mut a_connected) = on_connected(); + let (b_notifier, mut b_connected) = on_connected(); + + let keepalive_interval = Some(Duration::from_secs(3600)); //time.Hour + let check_interval = Duration::from_secs(3600); //time.Hour + let cfg0 = AgentConfig { + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + net: Some(net0), + + keepalive_interval, + check_interval, + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + a_agent.on_connection_state_change(a_notifier); + + let cfg1 = AgentConfig { + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + net: Some(net1), + + keepalive_interval, + check_interval, + ..Default::default() + }; + + let b_agent = Arc::new(Agent::new(cfg1).await?); + b_agent.on_connection_state_change(b_notifier); + + // Manual signaling + let (a_ufrag, a_pwd) = a_agent.get_local_user_credentials().await; + let (b_ufrag, b_pwd) = b_agent.get_local_user_credentials().await; + + gather_and_exchange_candidates(&a_agent, &b_agent).await?; + + let (accepted_tx, mut accepted_rx) = mpsc::channel::<()>(1); + let (accepting_tx, mut accepting_rx) = mpsc::channel::<()>(1); + let (_a_cancel_tx, a_cancel_rx) = mpsc::channel(1); + let (_b_cancel_tx, b_cancel_rx) = mpsc::channel(1); + + let accepting_tx = Arc::new(Mutex::new(Some(accepting_tx))); + a_agent.on_connection_state_change(Box::new(move |s: ConnectionState| { + let accepted_tx_clone = Arc::clone(&accepting_tx); + Box::pin(async move { + if s == ConnectionState::Checking { + let mut tx = accepted_tx_clone.lock().await; + tx.take(); + } + }) + })); + + tokio::spawn(async move { + let result = a_agent.accept(a_cancel_rx, b_ufrag, b_pwd).await; + assert!(result.is_ok(), "agent accept expected OK"); + drop(accepted_tx); + }); + + let _ = accepting_rx.recv().await; + + let _ = b_agent.dial(b_cancel_rx, a_ufrag, a_pwd).await?; + + // Ensure accepted + let _ = accepted_rx.recv().await; + + // Ensure pair selected + // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair + let _ = a_connected.recv().await; + let _ = b_connected.recv().await; + + { + let mut w = wan.lock().await; + w.stop().await?; + } + + Ok(()) +} + +#[tokio::test] +async fn test_connectivity_lite() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init();*/ + + let 
stun_server_url = Url { + scheme: SchemeType::Stun, + host: "1.2.3.4".to_owned(), + port: 3478, + proto: ProtoType::Udp, + ..Default::default() + }; + + let nat_type = nat::NatType { + mapping_behavior: nat::EndpointDependencyType::EndpointIndependent, + filtering_behavior: nat::EndpointDependencyType::EndpointIndependent, + ..Default::default() + }; + + let v = build_vnet(nat_type, nat_type).await?; + + let (a_notifier, mut a_connected) = on_connected(); + let (b_notifier, mut b_connected) = on_connected(); + + let cfg0 = AgentConfig { + urls: vec![stun_server_url], + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + net: Some(Arc::clone(&v.net0)), + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + a_agent.on_connection_state_change(a_notifier); + + let cfg1 = AgentConfig { + urls: vec![], + lite: true, + candidate_types: vec![CandidateType::Host], + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + net: Some(Arc::clone(&v.net1)), + ..Default::default() + }; + + let b_agent = Arc::new(Agent::new(cfg1).await?); + b_agent.on_connection_state_change(b_notifier); + + let _ = connect_with_vnet(&a_agent, &b_agent).await?; + + // Ensure pair selected + // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair + let _ = a_connected.recv().await; + let _ = b_connected.recv().await; + + v.close().await?; + + Ok(()) +} + +struct MockPacketConn; + +#[async_trait] +impl Conn for MockPacketConn { + async fn connect(&self, _addr: SocketAddr) -> std::result::Result<(), util::Error> { + Ok(()) + } + + async fn recv(&self, _buf: &mut [u8]) -> std::result::Result { + Ok(0) + } + + async fn recv_from( + &self, + _buf: &mut [u8], + ) -> std::result::Result<(usize, SocketAddr), util::Error> { + Ok((0, SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0))) + } + + async fn send(&self, _buf: &[u8]) -> std::result::Result { + Ok(0) + } + + async fn send_to( + &self, + _buf: &[u8], + _target: SocketAddr, + ) -> std::result::Result { + Ok(0) + } + + fn local_addr(&self) -> std::result::Result { + Ok(SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0)) + } + + fn remote_addr(&self) -> Option { + None + } + + async fn close(&self) -> std::result::Result<(), util::Error> { + Ok(()) + } +} + +fn build_msg(c: MessageClass, username: String, key: String) -> Result { + let mut msg = Message::new(); + msg.build(&[ + Box::new(MessageType::new(METHOD_BINDING, c)), + Box::new(TransactionId::new()), + Box::new(Username::new(ATTR_USERNAME, username)), + Box::new(MessageIntegrity::new_short_term_integrity(key)), + Box::new(FINGERPRINT), + ])?; + Ok(msg) +} + +#[tokio::test] +async fn test_inbound_validity() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init();*/ + + let remote = SocketAddr::from_str("172.17.0.3:999")?; + let local: Arc = Arc::new( + CandidateHostConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "192.168.0.2".to_owned(), + port: 777, + component: 1, + conn: Some(Arc::new(MockPacketConn {})), + ..Default::default() + }, + ..Default::default() + } + .new_candidate_host()?, + ); + + //"Invalid Binding requests should be discarded" + { + let a = 
Agent::new(AgentConfig::default()).await?; + + { + let local_pwd = { + let ufrag_pwd = a.internal.ufrag_pwd.lock().await; + ufrag_pwd.local_pwd.clone() + }; + a.internal + .handle_inbound( + &mut build_msg(CLASS_REQUEST, "invalid".to_owned(), local_pwd)?, + &local, + remote, + ) + .await; + { + let remote_candidates = a.internal.remote_candidates.lock().await; + assert_ne!( + remote_candidates.len(), + 1, + "Binding with invalid Username was able to create prflx candidate" + ); + } + + let username = { + let ufrag_pwd = a.internal.ufrag_pwd.lock().await; + format!("{}:{}", ufrag_pwd.local_ufrag, ufrag_pwd.remote_ufrag) + }; + a.internal + .handle_inbound( + &mut build_msg(CLASS_REQUEST, username, "Invalid".to_owned())?, + &local, + remote, + ) + .await; + { + let remote_candidates = a.internal.remote_candidates.lock().await; + assert_ne!( + remote_candidates.len(), + 1, + "Binding with invalid MessageIntegrity was able to create prflx candidate" + ); + } + } + + a.close().await?; + } + + //"Invalid Binding success responses should be discarded" + { + let a = Agent::new(AgentConfig::default()).await?; + + { + let username = { + let ufrag_pwd = a.internal.ufrag_pwd.lock().await; + format!("{}:{}", ufrag_pwd.local_ufrag, ufrag_pwd.remote_ufrag) + }; + a.internal + .handle_inbound( + &mut build_msg(CLASS_SUCCESS_RESPONSE, username, "Invalid".to_owned())?, + &local, + remote, + ) + .await; + { + let remote_candidates = a.internal.remote_candidates.lock().await; + assert_ne!( + remote_candidates.len(), + 1, + "Binding with invalid Username was able to create prflx candidate" + ); + } + } + + a.close().await?; + } + + //"Discard non-binding messages" + { + let a = Agent::new(AgentConfig::default()).await?; + + { + let username = { + let ufrag_pwd = a.internal.ufrag_pwd.lock().await; + format!("{}:{}", ufrag_pwd.local_ufrag, ufrag_pwd.remote_ufrag) + }; + a.internal + .handle_inbound( + &mut build_msg(CLASS_ERROR_RESPONSE, username, "Invalid".to_owned())?, + &local, + remote, + ) + .await; + let remote_candidates = a.internal.remote_candidates.lock().await; + assert_ne!( + remote_candidates.len(), + 1, + "non-binding message was able to create prflxRemote" + ); + } + + a.close().await?; + } + + //"Valid bind request" + { + let a = Agent::new(AgentConfig::default()).await?; + + { + let (username, local_pwd) = { + let ufrag_pwd = a.internal.ufrag_pwd.lock().await; + ( + format!("{}:{}", ufrag_pwd.local_ufrag, ufrag_pwd.remote_ufrag), + ufrag_pwd.local_pwd.clone(), + ) + }; + a.internal + .handle_inbound( + &mut build_msg(CLASS_REQUEST, username, local_pwd)?, + &local, + remote, + ) + .await; + let remote_candidates = a.internal.remote_candidates.lock().await; + assert_eq!( + remote_candidates.len(), + 1, + "Binding with valid values was unable to create prflx candidate" + ); + } + + a.close().await?; + } + + //"Valid bind without fingerprint" + { + let a = Agent::new(AgentConfig::default()).await?; + + { + let (username, local_pwd) = { + let ufrag_pwd = a.internal.ufrag_pwd.lock().await; + ( + format!("{}:{}", ufrag_pwd.local_ufrag, ufrag_pwd.remote_ufrag), + ufrag_pwd.local_pwd.clone(), + ) + }; + + let mut msg = Message::new(); + msg.build(&[ + Box::new(BINDING_REQUEST), + Box::new(TransactionId::new()), + Box::new(Username::new(ATTR_USERNAME, username)), + Box::new(MessageIntegrity::new_short_term_integrity(local_pwd)), + ])?; + + a.internal.handle_inbound(&mut msg, &local, remote).await; + let remote_candidates = a.internal.remote_candidates.lock().await; + assert_eq!( + remote_candidates.len(), 
+ 1, + "Binding with valid values (but no fingerprint) was unable to create prflx candidate" + ); + } + + a.close().await?; + } + + //"Success with invalid TransactionID" + { + let a = Agent::new(AgentConfig::default()).await?; + + { + let remote = SocketAddr::from_str("172.17.0.3:999")?; + + let mut t_id = TransactionId::default(); + t_id.0[..3].copy_from_slice(b"ABC"); + + let remote_pwd = { + let ufrag_pwd = a.internal.ufrag_pwd.lock().await; + ufrag_pwd.remote_pwd.clone() + }; + + let mut msg = Message::new(); + msg.build(&[ + Box::new(BINDING_SUCCESS), + Box::new(t_id), + Box::new(MessageIntegrity::new_short_term_integrity(remote_pwd)), + Box::new(FINGERPRINT), + ])?; + + a.internal.handle_inbound(&mut msg, &local, remote).await; + + { + let remote_candidates = a.internal.remote_candidates.lock().await; + assert_eq!( + remote_candidates.len(), + 0, + "unknown remote was able to create a candidate" + ); + } + } + + a.close().await?; + } + + Ok(()) +} + +#[tokio::test] +async fn test_invalid_agent_starts() -> Result<()> { + let a = Agent::new(AgentConfig::default()).await?; + + let (_cancel_tx1, cancel_rx1) = mpsc::channel(1); + let result = a.dial(cancel_rx1, "".to_owned(), "bar".to_owned()).await; + assert!(result.is_err()); + if let Err(err) = result { + assert_eq!(Error::ErrRemoteUfragEmpty, err); + } + + let (_cancel_tx2, cancel_rx2) = mpsc::channel(1); + let result = a.dial(cancel_rx2, "foo".to_owned(), "".to_owned()).await; + assert!(result.is_err()); + if let Err(err) = result { + assert_eq!(Error::ErrRemotePwdEmpty, err); + } + + let (cancel_tx3, cancel_rx3) = mpsc::channel(1); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(100)).await; + drop(cancel_tx3); + }); + + let result = a.dial(cancel_rx3, "foo".to_owned(), "bar".to_owned()).await; + assert!(result.is_err()); + if let Err(err) = result { + assert_eq!(Error::ErrCanceledByCaller, err); + } + + let (_cancel_tx4, cancel_rx4) = mpsc::channel(1); + let result = a.dial(cancel_rx4, "foo".to_owned(), "bar".to_owned()).await; + assert!(result.is_err()); + if let Err(err) = result { + assert_eq!(Error::ErrMultipleStart, err); + } + + a.close().await?; + + Ok(()) +} + +//use std::io::Write; + +// Assert that Agent emits Connecting/Connected/Disconnected/Failed/Closed messages +#[tokio::test] +async fn test_connection_state_callback() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init();*/ + + let disconnected_duration = Duration::from_secs(1); + let failed_duration = Duration::from_secs(1); + let keepalive_interval = Duration::from_secs(0); + + let cfg0 = AgentConfig { + urls: vec![], + network_types: supported_network_types(), + disconnected_timeout: Some(disconnected_duration), + failed_timeout: Some(failed_duration), + keepalive_interval: Some(keepalive_interval), + ..Default::default() + }; + + let cfg1 = AgentConfig { + urls: vec![], + network_types: supported_network_types(), + disconnected_timeout: Some(disconnected_duration), + failed_timeout: Some(failed_duration), + keepalive_interval: Some(keepalive_interval), + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + let b_agent = Arc::new(Agent::new(cfg1).await?); + + let (is_checking_tx, mut is_checking_rx) = mpsc::channel::<()>(1); + let 
(is_connected_tx, mut is_connected_rx) = mpsc::channel::<()>(1); + let (is_disconnected_tx, mut is_disconnected_rx) = mpsc::channel::<()>(1); + let (is_failed_tx, mut is_failed_rx) = mpsc::channel::<()>(1); + let (is_closed_tx, mut is_closed_rx) = mpsc::channel::<()>(1); + + let is_checking_tx = Arc::new(Mutex::new(Some(is_checking_tx))); + let is_connected_tx = Arc::new(Mutex::new(Some(is_connected_tx))); + let is_disconnected_tx = Arc::new(Mutex::new(Some(is_disconnected_tx))); + let is_failed_tx = Arc::new(Mutex::new(Some(is_failed_tx))); + let is_closed_tx = Arc::new(Mutex::new(Some(is_closed_tx))); + + a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { + let is_checking_tx_clone = Arc::clone(&is_checking_tx); + let is_connected_tx_clone = Arc::clone(&is_connected_tx); + let is_disconnected_tx_clone = Arc::clone(&is_disconnected_tx); + let is_failed_tx_clone = Arc::clone(&is_failed_tx); + let is_closed_tx_clone = Arc::clone(&is_closed_tx); + Box::pin(async move { + match c { + ConnectionState::Checking => { + log::debug!("drop is_checking_tx"); + let mut tx = is_checking_tx_clone.lock().await; + tx.take(); + } + ConnectionState::Connected => { + log::debug!("drop is_connected_tx"); + let mut tx = is_connected_tx_clone.lock().await; + tx.take(); + } + ConnectionState::Disconnected => { + log::debug!("drop is_disconnected_tx"); + let mut tx = is_disconnected_tx_clone.lock().await; + tx.take(); + } + ConnectionState::Failed => { + log::debug!("drop is_failed_tx"); + let mut tx = is_failed_tx_clone.lock().await; + tx.take(); + } + ConnectionState::Closed => { + log::debug!("drop is_closed_tx"); + let mut tx = is_closed_tx_clone.lock().await; + tx.take(); + } + _ => {} + }; + }) + })); + + connect_with_vnet(&a_agent, &b_agent).await?; + + log::debug!("wait is_checking_tx"); + let _ = is_checking_rx.recv().await; + log::debug!("wait is_connected_rx"); + let _ = is_connected_rx.recv().await; + log::debug!("wait is_disconnected_rx"); + let _ = is_disconnected_rx.recv().await; + log::debug!("wait is_failed_rx"); + let _ = is_failed_rx.recv().await; + + a_agent.close().await?; + b_agent.close().await?; + + log::debug!("wait is_closed_rx"); + let _ = is_closed_rx.recv().await; + + Ok(()) +} + +#[tokio::test] +async fn test_invalid_gather() -> Result<()> { + //"Gather with no OnCandidate should error" + let a = Agent::new(AgentConfig::default()).await?; + + if let Err(err) = a.gather_candidates() { + assert_eq!( + Error::ErrNoOnCandidateHandler, + err, + "trickle GatherCandidates succeeded without OnCandidate" + ); + } + + a.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_candidate_pair_stats() -> Result<()> { + let a = Agent::new(AgentConfig::default()).await?; + + let host_local: Arc = Arc::new( + CandidateHostConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "192.168.1.1".to_owned(), + port: 19216, + component: 1, + ..Default::default() + }, + ..Default::default() + } + .new_candidate_host()?, + ); + + let relay_remote: Arc = Arc::new( + CandidateRelayConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "1.2.3.4".to_owned(), + port: 2340, + component: 1, + ..Default::default() + }, + rel_addr: "4.3.2.1".to_owned(), + rel_port: 43210, + ..Default::default() + } + .new_candidate_relay()?, + ); + + let srflx_remote: Arc = Arc::new( + CandidateServerReflexiveConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "10.10.10.2".to_owned(), + port: 19218, + component: 
1, + ..Default::default() + }, + rel_addr: "4.3.2.1".to_owned(), + rel_port: 43212, + } + .new_candidate_server_reflexive()?, + ); + + let prflx_remote: Arc = Arc::new( + CandidatePeerReflexiveConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "10.10.10.2".to_owned(), + port: 19217, + component: 1, + ..Default::default() + }, + rel_addr: "4.3.2.1".to_owned(), + rel_port: 43211, + } + .new_candidate_peer_reflexive()?, + ); + + let host_remote: Arc = Arc::new( + CandidateHostConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "1.2.3.5".to_owned(), + port: 12350, + component: 1, + ..Default::default() + }, + ..Default::default() + } + .new_candidate_host()?, + ); + + for remote in &[ + Arc::clone(&relay_remote), + Arc::clone(&srflx_remote), + Arc::clone(&prflx_remote), + Arc::clone(&host_remote), + ] { + let p = a.internal.find_pair(&host_local, remote).await; + + if p.is_none() { + a.internal + .add_pair(Arc::clone(&host_local), Arc::clone(remote)) + .await; + } + } + + { + if let Some(p) = a.internal.find_pair(&host_local, &prflx_remote).await { + p.state + .store(CandidatePairState::Failed as u8, Ordering::SeqCst); + } + } + + let stats = a.get_candidate_pairs_stats().await; + assert_eq!(stats.len(), 4, "expected 4 candidate pairs stats"); + + let (mut relay_pair_stat, mut srflx_pair_stat, mut prflx_pair_stat, mut host_pair_stat) = ( + CandidatePairStats::default(), + CandidatePairStats::default(), + CandidatePairStats::default(), + CandidatePairStats::default(), + ); + + for cps in stats { + assert_eq!( + cps.local_candidate_id, + host_local.id(), + "invalid local candidate id" + ); + + if cps.remote_candidate_id == relay_remote.id() { + relay_pair_stat = cps; + } else if cps.remote_candidate_id == srflx_remote.id() { + srflx_pair_stat = cps; + } else if cps.remote_candidate_id == prflx_remote.id() { + prflx_pair_stat = cps; + } else if cps.remote_candidate_id == host_remote.id() { + host_pair_stat = cps; + } else { + panic!("invalid remote candidate ID"); + } + } + + assert_eq!( + relay_pair_stat.remote_candidate_id, + relay_remote.id(), + "missing host-relay pair stat" + ); + assert_eq!( + srflx_pair_stat.remote_candidate_id, + srflx_remote.id(), + "missing host-srflx pair stat" + ); + assert_eq!( + prflx_pair_stat.remote_candidate_id, + prflx_remote.id(), + "missing host-prflx pair stat" + ); + assert_eq!( + host_pair_stat.remote_candidate_id, + host_remote.id(), + "missing host-host pair stat" + ); + assert_eq!( + prflx_pair_stat.state, + CandidatePairState::Failed, + "expected host-prfflx pair to have state failed, it has state {} instead", + prflx_pair_stat.state + ); + + a.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_local_candidate_stats() -> Result<()> { + let a = Agent::new(AgentConfig::default()).await?; + + let host_local: Arc = Arc::new( + CandidateHostConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "192.168.1.1".to_owned(), + port: 19216, + component: 1, + ..Default::default() + }, + ..Default::default() + } + .new_candidate_host()?, + ); + + let srflx_local: Arc = Arc::new( + CandidateServerReflexiveConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "192.168.1.1".to_owned(), + port: 19217, + component: 1, + ..Default::default() + }, + rel_addr: "4.3.2.1".to_owned(), + rel_port: 43212, + } + .new_candidate_server_reflexive()?, + ); + + { + let mut local_candidates = a.internal.local_candidates.lock().await; + 
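+        // Inject the two candidates directly into the agent's internal
+        // candidate map so get_local_candidates_stats() can be exercised
+        // without running a real gathering pass.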
local_candidates.insert( + NetworkType::Udp4, + vec![Arc::clone(&host_local), Arc::clone(&srflx_local)], + ); + } + + let local_stats = a.get_local_candidates_stats().await; + assert_eq!( + local_stats.len(), + 2, + "expected 2 local candidates stats, got {} instead", + local_stats.len() + ); + + let (mut host_local_stat, mut srflx_local_stat) = + (CandidateStats::default(), CandidateStats::default()); + for stats in local_stats { + let candidate = if stats.id == host_local.id() { + host_local_stat = stats.clone(); + Arc::clone(&host_local) + } else if stats.id == srflx_local.id() { + srflx_local_stat = stats.clone(); + Arc::clone(&srflx_local) + } else { + panic!("invalid local candidate ID"); + }; + + assert_eq!( + stats.candidate_type, + candidate.candidate_type(), + "invalid stats CandidateType" + ); + assert_eq!( + stats.priority, + candidate.priority(), + "invalid stats CandidateType" + ); + assert_eq!(stats.ip, candidate.address(), "invalid stats IP"); + } + + assert_eq!( + host_local_stat.id, + host_local.id(), + "missing host local stat" + ); + assert_eq!( + srflx_local_stat.id, + srflx_local.id(), + "missing srflx local stat" + ); + + a.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_remote_candidate_stats() -> Result<()> { + let a = Agent::new(AgentConfig::default()).await?; + + let relay_remote: Arc = Arc::new( + CandidateRelayConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "1.2.3.4".to_owned(), + port: 12340, + component: 1, + ..Default::default() + }, + rel_addr: "4.3.2.1".to_owned(), + rel_port: 43210, + ..Default::default() + } + .new_candidate_relay()?, + ); + + let srflx_remote: Arc = Arc::new( + CandidateServerReflexiveConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "10.10.10.2".to_owned(), + port: 19218, + component: 1, + ..Default::default() + }, + rel_addr: "4.3.2.1".to_owned(), + rel_port: 43212, + } + .new_candidate_server_reflexive()?, + ); + + let prflx_remote: Arc = Arc::new( + CandidatePeerReflexiveConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "10.10.10.2".to_owned(), + port: 19217, + component: 1, + ..Default::default() + }, + rel_addr: "4.3.2.1".to_owned(), + rel_port: 43211, + } + .new_candidate_peer_reflexive()?, + ); + + let host_remote: Arc = Arc::new( + CandidateHostConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "1.2.3.5".to_owned(), + port: 12350, + component: 1, + ..Default::default() + }, + ..Default::default() + } + .new_candidate_host()?, + ); + + { + let mut remote_candidates = a.internal.remote_candidates.lock().await; + remote_candidates.insert( + NetworkType::Udp4, + vec![ + Arc::clone(&relay_remote), + Arc::clone(&srflx_remote), + Arc::clone(&prflx_remote), + Arc::clone(&host_remote), + ], + ); + } + + let remote_stats = a.get_remote_candidates_stats().await; + assert_eq!( + remote_stats.len(), + 4, + "expected 4 remote candidates stats, got {} instead", + remote_stats.len() + ); + + let (mut relay_remote_stat, mut srflx_remote_stat, mut prflx_remote_stat, mut host_remote_stat) = ( + CandidateStats::default(), + CandidateStats::default(), + CandidateStats::default(), + CandidateStats::default(), + ); + for stats in remote_stats { + let candidate = if stats.id == relay_remote.id() { + relay_remote_stat = stats.clone(); + Arc::clone(&relay_remote) + } else if stats.id == srflx_remote.id() { + srflx_remote_stat = stats.clone(); + Arc::clone(&srflx_remote) + } else if stats.id == 
prflx_remote.id() { + prflx_remote_stat = stats.clone(); + Arc::clone(&prflx_remote) + } else if stats.id == host_remote.id() { + host_remote_stat = stats.clone(); + Arc::clone(&host_remote) + } else { + panic!("invalid remote candidate ID"); + }; + + assert_eq!( + stats.candidate_type, + candidate.candidate_type(), + "invalid stats CandidateType" + ); + assert_eq!( + stats.priority, + candidate.priority(), + "invalid stats CandidateType" + ); + assert_eq!(stats.ip, candidate.address(), "invalid stats IP"); + } + + assert_eq!( + relay_remote_stat.id, + relay_remote.id(), + "missing relay remote stat" + ); + assert_eq!( + srflx_remote_stat.id, + srflx_remote.id(), + "missing srflx remote stat" + ); + assert_eq!( + prflx_remote_stat.id, + prflx_remote.id(), + "missing prflx remote stat" + ); + assert_eq!( + host_remote_stat.id, + host_remote.id(), + "missing host remote stat" + ); + + a.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_init_ext_ip_mapping() -> Result<()> { + // a.extIPMapper should be nil by default + let a = Agent::new(AgentConfig::default()).await?; + assert!( + a.ext_ip_mapper.is_none(), + "a.extIPMapper should be none by default" + ); + a.close().await?; + + // a.extIPMapper should be nil when NAT1To1IPs is a non-nil empty array + let a = Agent::new(AgentConfig { + nat_1to1_ips: vec![], + nat_1to1_ip_candidate_type: CandidateType::Host, + ..Default::default() + }) + .await?; + assert!( + a.ext_ip_mapper.is_none(), + "a.extIPMapper should be none by default" + ); + a.close().await?; + + // NewAgent should return an error when 1:1 NAT for host candidate is enabled + // but the candidate type does not appear in the CandidateTypes. + if let Err(err) = Agent::new(AgentConfig { + nat_1to1_ips: vec!["1.2.3.4".to_owned()], + nat_1to1_ip_candidate_type: CandidateType::Host, + candidate_types: vec![CandidateType::Relay], + ..Default::default() + }) + .await + { + assert_eq!( + Error::ErrIneffectiveNat1to1IpMappingHost, + err, + "Unexpected error: {err}" + ); + } else { + panic!("expected error, but got ok"); + } + + // NewAgent should return an error when 1:1 NAT for srflx candidate is enabled + // but the candidate type does not appear in the CandidateTypes. + if let Err(err) = Agent::new(AgentConfig { + nat_1to1_ips: vec!["1.2.3.4".to_owned()], + nat_1to1_ip_candidate_type: CandidateType::ServerReflexive, + candidate_types: vec![CandidateType::Relay], + ..Default::default() + }) + .await + { + assert_eq!( + Error::ErrIneffectiveNat1to1IpMappingSrflx, + err, + "Unexpected error: {err}" + ); + } else { + panic!("expected error, but got ok"); + } + + // NewAgent should return an error when 1:1 NAT for host candidate is enabled + // along with mDNS with MulticastDNSModeQueryAndGather + if let Err(err) = Agent::new(AgentConfig { + nat_1to1_ips: vec!["1.2.3.4".to_owned()], + nat_1to1_ip_candidate_type: CandidateType::Host, + multicast_dns_mode: MulticastDnsMode::QueryAndGather, + ..Default::default() + }) + .await + { + assert_eq!( + Error::ErrMulticastDnsWithNat1to1IpMapping, + err, + "Unexpected error: {err}" + ); + } else { + panic!("expected error, but got ok"); + } + + // NewAgent should return if newExternalIPMapper() returns an error. 
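+    // "bad.2.3.4" is not a parseable IP address, so building the external IP
+    // mapper fails and Agent::new surfaces ErrInvalidNat1to1IpMapping.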
+ if let Err(err) = Agent::new(AgentConfig { + nat_1to1_ips: vec!["bad.2.3.4".to_owned()], // bad IP + nat_1to1_ip_candidate_type: CandidateType::Host, + ..Default::default() + }) + .await + { + assert_eq!( + Error::ErrInvalidNat1to1IpMapping, + err, + "Unexpected error: {err}" + ); + } else { + panic!("expected error, but got ok"); + } + + Ok(()) +} + +#[tokio::test] +async fn test_binding_request_timeout() -> Result<()> { + const EXPECTED_REMOVAL_COUNT: usize = 2; + + let a = Agent::new(AgentConfig::default()).await?; + + let now = Instant::now(); + { + { + let mut pending_binding_requests = a.internal.pending_binding_requests.lock().await; + pending_binding_requests.push(BindingRequest { + timestamp: now, // valid + ..Default::default() + }); + pending_binding_requests.push(BindingRequest { + timestamp: now.sub(Duration::from_millis(3900)), // valid + ..Default::default() + }); + pending_binding_requests.push(BindingRequest { + timestamp: now.sub(Duration::from_millis(4100)), // invalid + ..Default::default() + }); + pending_binding_requests.push(BindingRequest { + timestamp: now.sub(Duration::from_secs(75)), // invalid + ..Default::default() + }); + } + + a.internal.invalidate_pending_binding_requests(now).await; + { + let pending_binding_requests = a.internal.pending_binding_requests.lock().await; + assert_eq!(pending_binding_requests.len(), EXPECTED_REMOVAL_COUNT, "Binding invalidation due to timeout did not remove the correct number of binding requests") + } + } + + a.close().await?; + + Ok(()) +} + +// test_agent_credentials checks if local username fragments and passwords (if set) meet RFC standard +// and ensure it's backwards compatible with previous versions of the pion/ice +#[tokio::test] +async fn test_agent_credentials() -> Result<()> { + // Agent should not require any of the usernames and password to be set + // If set, they should follow the default 16/128 bits random number generator strategy + + let a = Agent::new(AgentConfig::default()).await?; + { + let ufrag_pwd = a.internal.ufrag_pwd.lock().await; + assert!(ufrag_pwd.local_ufrag.as_bytes().len() * 8 >= 24); + assert!(ufrag_pwd.local_pwd.as_bytes().len() * 8 >= 128); + } + a.close().await?; + + // Should honor RFC standards + // Local values MUST be unguessable, with at least 128 bits of + // random number generator output used to generate the password, and + // at least 24 bits of output to generate the username fragment. 
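+    // In byte terms that means at least 3 bytes of ufrag (24 bits) and at
+    // least 16 bytes of password (128 bits), matching the asserts above; the
+    // two-character ufrag and six-character pwd below are therefore rejected.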
+ + if let Err(err) = Agent::new(AgentConfig { + local_ufrag: "xx".to_owned(), + ..Default::default() + }) + .await + { + assert_eq!(Error::ErrLocalUfragInsufficientBits, err); + } else { + panic!("expected error, but got ok"); + } + + if let Err(err) = Agent::new(AgentConfig { + local_pwd: "xxxxxx".to_owned(), + ..Default::default() + }) + .await + { + assert_eq!(Error::ErrLocalPwdInsufficientBits, err); + } else { + panic!("expected error, but got ok"); + } + + Ok(()) +} + +// Assert that Agent on Failure deletes all existing candidates +// User can then do an ICE Restart to bring agent back +#[tokio::test] +async fn test_connection_state_failed_delete_all_candidates() -> Result<()> { + let one_second = Duration::from_secs(1); + let keepalive_interval = Duration::from_secs(0); + + let cfg0 = AgentConfig { + network_types: supported_network_types(), + disconnected_timeout: Some(one_second), + failed_timeout: Some(one_second), + keepalive_interval: Some(keepalive_interval), + ..Default::default() + }; + let cfg1 = AgentConfig { + network_types: supported_network_types(), + disconnected_timeout: Some(one_second), + failed_timeout: Some(one_second), + keepalive_interval: Some(keepalive_interval), + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + let b_agent = Arc::new(Agent::new(cfg1).await?); + + let (is_failed_tx, mut is_failed_rx) = mpsc::channel::<()>(1); + let is_failed_tx = Arc::new(Mutex::new(Some(is_failed_tx))); + a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { + let is_failed_tx_clone = Arc::clone(&is_failed_tx); + Box::pin(async move { + if c == ConnectionState::Failed { + let mut tx = is_failed_tx_clone.lock().await; + tx.take(); + } + }) + })); + + connect_with_vnet(&a_agent, &b_agent).await?; + let _ = is_failed_rx.recv().await; + + { + { + let remote_candidates = a_agent.internal.remote_candidates.lock().await; + assert_eq!(remote_candidates.len(), 0); + } + { + let local_candidates = a_agent.internal.local_candidates.lock().await; + assert_eq!(local_candidates.len(), 0); + } + } + + a_agent.close().await?; + b_agent.close().await?; + + Ok(()) +} + +// Assert that the ICE Agent can go directly from Connecting -> Failed on both sides +#[tokio::test] +async fn test_connection_state_connecting_to_failed() -> Result<()> { + let one_second = Duration::from_secs(1); + let keepalive_interval = Duration::from_secs(0); + + let cfg0 = AgentConfig { + disconnected_timeout: Some(one_second), + failed_timeout: Some(one_second), + keepalive_interval: Some(keepalive_interval), + ..Default::default() + }; + let cfg1 = AgentConfig { + disconnected_timeout: Some(one_second), + failed_timeout: Some(one_second), + keepalive_interval: Some(keepalive_interval), + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + let b_agent = Arc::new(Agent::new(cfg1).await?); + + let is_failed = WaitGroup::new(); + let is_checking = WaitGroup::new(); + + let connection_state_check = move |wf: Worker, wc: Worker| { + let wf = Arc::new(Mutex::new(Some(wf))); + let wc = Arc::new(Mutex::new(Some(wc))); + let hdlr_fn: OnConnectionStateChangeHdlrFn = Box::new(move |c: ConnectionState| { + let wf_clone = Arc::clone(&wf); + let wc_clone = Arc::clone(&wc); + Box::pin(async move { + if c == ConnectionState::Failed { + let mut f = wf_clone.lock().await; + f.take(); + } else if c == ConnectionState::Checking { + let mut c = wc_clone.lock().await; + c.take(); + } else if c == ConnectionState::Connected || c == ConnectionState::Completed { + 
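+                    // Both agents are started with deliberately invalid remote
+                    // credentials below, so a transition to Connected or
+                    // Completed would mean a connectivity check succeeded that
+                    // should have failed.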
panic!("Unexpected ConnectionState: {c}"); + } + }) + }); + hdlr_fn + }; + + let (wf1, wc1) = (is_failed.worker(), is_checking.worker()); + a_agent.on_connection_state_change(connection_state_check(wf1, wc1)); + + let (wf2, wc2) = (is_failed.worker(), is_checking.worker()); + b_agent.on_connection_state_change(connection_state_check(wf2, wc2)); + + let agent_a = Arc::clone(&a_agent); + tokio::spawn(async move { + let (_cancel_tx, cancel_rx) = mpsc::channel(1); + let result = agent_a + .accept(cancel_rx, "InvalidFrag".to_owned(), "InvalidPwd".to_owned()) + .await; + assert!(result.is_err()); + }); + + let agent_b = Arc::clone(&b_agent); + tokio::spawn(async move { + let (_cancel_tx, cancel_rx) = mpsc::channel(1); + let result = agent_b + .dial(cancel_rx, "InvalidFrag".to_owned(), "InvalidPwd".to_owned()) + .await; + assert!(result.is_err()); + }); + + is_checking.wait().await; + is_failed.wait().await; + + a_agent.close().await?; + b_agent.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_agent_restart_during_gather() -> Result<()> { + //"Restart During Gather" + + let agent = Agent::new(AgentConfig::default()).await?; + + agent + .gathering_state + .store(GatheringState::Gathering as u8, Ordering::SeqCst); + + if let Err(err) = agent.restart("".to_owned(), "".to_owned()).await { + assert_eq!(Error::ErrRestartWhenGathering, err); + } else { + panic!("expected error, but got ok"); + } + + agent.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_agent_restart_when_closed() -> Result<()> { + //"Restart When Closed" + + let agent = Agent::new(AgentConfig::default()).await?; + agent.close().await?; + + if let Err(err) = agent.restart("".to_owned(), "".to_owned()).await { + assert_eq!(Error::ErrClosed, err); + } else { + panic!("expected error, but got ok"); + } + + Ok(()) +} + +#[tokio::test] +async fn test_agent_restart_one_side() -> Result<()> { + let one_second = Duration::from_secs(1); + + //"Restart One Side" + let (_, _, agent_a, agent_b) = pipe( + Some(AgentConfig { + disconnected_timeout: Some(one_second), + failed_timeout: Some(one_second), + ..Default::default() + }), + Some(AgentConfig { + disconnected_timeout: Some(one_second), + failed_timeout: Some(one_second), + ..Default::default() + }), + ) + .await?; + + let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); + let cancel_tx = Arc::new(Mutex::new(Some(cancel_tx))); + agent_b.on_connection_state_change(Box::new(move |c: ConnectionState| { + let cancel_tx_clone = Arc::clone(&cancel_tx); + Box::pin(async move { + if c == ConnectionState::Failed || c == ConnectionState::Disconnected { + let mut tx = cancel_tx_clone.lock().await; + tx.take(); + } + }) + })); + + agent_a.restart("".to_owned(), "".to_owned()).await?; + + let _ = cancel_rx.recv().await; + + agent_a.close().await?; + agent_b.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_agent_restart_both_side() -> Result<()> { + let one_second = Duration::from_secs(1); + //"Restart Both Sides" + + // Get all addresses of candidates concatenated + let generate_candidate_address_strings = + |res: Result>>| -> String { + assert!(res.is_ok()); + + let mut out = String::new(); + if let Ok(candidates) = res { + for c in candidates { + out += c.address().as_str(); + out += ":"; + out += c.port().to_string().as_str(); + } + } + out + }; + + // Store the original candidates, confirm that after we reconnect we have new pairs + let (_, _, agent_a, agent_b) = pipe( + Some(AgentConfig { + disconnected_timeout: Some(one_second), + failed_timeout: 
Some(one_second), + ..Default::default() + }), + Some(AgentConfig { + disconnected_timeout: Some(one_second), + failed_timeout: Some(one_second), + ..Default::default() + }), + ) + .await?; + + let conn_afirst_candidates = + generate_candidate_address_strings(agent_a.get_local_candidates().await); + let conn_bfirst_candidates = + generate_candidate_address_strings(agent_b.get_local_candidates().await); + + let (a_notifier, mut a_connected) = on_connected(); + agent_a.on_connection_state_change(a_notifier); + + let (b_notifier, mut b_connected) = on_connected(); + agent_b.on_connection_state_change(b_notifier); + + // Restart and Re-Signal + agent_a.restart("".to_owned(), "".to_owned()).await?; + agent_b.restart("".to_owned(), "".to_owned()).await?; + + // Exchange Candidates and Credentials + let (ufrag, pwd) = agent_b.get_local_user_credentials().await; + agent_a.set_remote_credentials(ufrag, pwd).await?; + + let (ufrag, pwd) = agent_a.get_local_user_credentials().await; + agent_b.set_remote_credentials(ufrag, pwd).await?; + + gather_and_exchange_candidates(&agent_a, &agent_b).await?; + + // Wait until both have gone back to connected + let _ = a_connected.recv().await; + let _ = b_connected.recv().await; + + // Assert that we have new candiates each time + assert_ne!( + conn_afirst_candidates, + generate_candidate_address_strings(agent_a.get_local_candidates().await) + ); + assert_ne!( + conn_bfirst_candidates, + generate_candidate_address_strings(agent_b.get_local_candidates().await) + ); + + agent_a.close().await?; + agent_b.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_get_remote_credentials() -> Result<()> { + let a = Agent::new(AgentConfig::default()).await?; + + let (remote_ufrag, remote_pwd) = { + let mut ufrag_pwd = a.internal.ufrag_pwd.lock().await; + ufrag_pwd.remote_ufrag = "remoteUfrag".to_owned(); + ufrag_pwd.remote_pwd = "remotePwd".to_owned(); + ( + ufrag_pwd.remote_ufrag.to_owned(), + ufrag_pwd.remote_pwd.to_owned(), + ) + }; + + let (actual_ufrag, actual_pwd) = a.get_remote_user_credentials().await; + + assert_eq!(actual_ufrag, remote_ufrag); + assert_eq!(actual_pwd, remote_pwd); + + a.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_close_in_connection_state_callback() -> Result<()> { + let disconnected_duration = Duration::from_secs(1); + let failed_duration = Duration::from_secs(1); + let keepalive_interval = Duration::from_secs(0); + + let cfg0 = AgentConfig { + urls: vec![], + network_types: supported_network_types(), + disconnected_timeout: Some(disconnected_duration), + failed_timeout: Some(failed_duration), + keepalive_interval: Some(keepalive_interval), + check_interval: Duration::from_millis(500), + ..Default::default() + }; + + let cfg1 = AgentConfig { + urls: vec![], + network_types: supported_network_types(), + disconnected_timeout: Some(disconnected_duration), + failed_timeout: Some(failed_duration), + keepalive_interval: Some(keepalive_interval), + check_interval: Duration::from_millis(500), + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + let b_agent = Arc::new(Agent::new(cfg1).await?); + + let (is_closed_tx, mut is_closed_rx) = mpsc::channel::<()>(1); + let (is_connected_tx, mut is_connected_rx) = mpsc::channel::<()>(1); + let is_closed_tx = Arc::new(Mutex::new(Some(is_closed_tx))); + let is_connected_tx = Arc::new(Mutex::new(Some(is_connected_tx))); + a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { + let is_closed_tx_clone = Arc::clone(&is_closed_tx); + let 
is_connected_tx_clone = Arc::clone(&is_connected_tx); + Box::pin(async move { + if c == ConnectionState::Connected { + let mut tx = is_connected_tx_clone.lock().await; + tx.take(); + } else if c == ConnectionState::Closed { + let mut tx = is_closed_tx_clone.lock().await; + tx.take(); + } + }) + })); + + connect_with_vnet(&a_agent, &b_agent).await?; + + let _ = is_connected_rx.recv().await; + a_agent.close().await?; + + let _ = is_closed_rx.recv().await; + b_agent.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_run_task_in_connection_state_callback() -> Result<()> { + let one_second = Duration::from_secs(1); + let keepalive_interval = Duration::from_secs(0); + + let cfg0 = AgentConfig { + urls: vec![], + network_types: supported_network_types(), + disconnected_timeout: Some(one_second), + failed_timeout: Some(one_second), + keepalive_interval: Some(keepalive_interval), + check_interval: Duration::from_millis(50), + ..Default::default() + }; + + let cfg1 = AgentConfig { + urls: vec![], + network_types: supported_network_types(), + disconnected_timeout: Some(one_second), + failed_timeout: Some(one_second), + keepalive_interval: Some(keepalive_interval), + check_interval: Duration::from_millis(50), + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + let b_agent = Arc::new(Agent::new(cfg1).await?); + + let (is_complete_tx, mut is_complete_rx) = mpsc::channel::<()>(1); + let is_complete_tx = Arc::new(Mutex::new(Some(is_complete_tx))); + a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { + let is_complete_tx_clone = Arc::clone(&is_complete_tx); + Box::pin(async move { + if c == ConnectionState::Connected { + let mut tx = is_complete_tx_clone.lock().await; + tx.take(); + } + }) + })); + + connect_with_vnet(&a_agent, &b_agent).await?; + + let _ = is_complete_rx.recv().await; + let _ = a_agent.get_local_user_credentials().await; + a_agent.restart("".to_owned(), "".to_owned()).await?; + + a_agent.close().await?; + b_agent.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_run_task_in_selected_candidate_pair_change_callback() -> Result<()> { + let one_second = Duration::from_secs(1); + let keepalive_interval = Duration::from_secs(0); + + let cfg0 = AgentConfig { + urls: vec![], + network_types: supported_network_types(), + disconnected_timeout: Some(one_second), + failed_timeout: Some(one_second), + keepalive_interval: Some(keepalive_interval), + check_interval: Duration::from_millis(50), + ..Default::default() + }; + + let cfg1 = AgentConfig { + urls: vec![], + network_types: supported_network_types(), + disconnected_timeout: Some(one_second), + failed_timeout: Some(one_second), + keepalive_interval: Some(keepalive_interval), + check_interval: Duration::from_millis(50), + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + let b_agent = Arc::new(Agent::new(cfg1).await?); + + let (is_tested_tx, mut is_tested_rx) = mpsc::channel::<()>(1); + let is_tested_tx = Arc::new(Mutex::new(Some(is_tested_tx))); + a_agent.on_selected_candidate_pair_change(Box::new( + move |_: &Arc, _: &Arc| { + let is_tested_tx_clone = Arc::clone(&is_tested_tx); + Box::pin(async move { + let mut tx = is_tested_tx_clone.lock().await; + tx.take(); + }) + }, + )); + + let (is_complete_tx, mut is_complete_rx) = mpsc::channel::<()>(1); + let is_complete_tx = Arc::new(Mutex::new(Some(is_complete_tx))); + a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { + let is_complete_tx_clone = 
Arc::clone(&is_complete_tx); + Box::pin(async move { + if c == ConnectionState::Connected { + let mut tx = is_complete_tx_clone.lock().await; + tx.take(); + } + }) + })); + + connect_with_vnet(&a_agent, &b_agent).await?; + + let _ = is_complete_rx.recv().await; + let _ = is_tested_rx.recv().await; + + let _ = a_agent.get_local_user_credentials().await; + + a_agent.close().await?; + b_agent.close().await?; + + Ok(()) +} + +// Assert that a Lite agent goes to disconnected and failed +#[tokio::test] +async fn test_lite_lifecycle() -> Result<()> { + let (a_notifier, mut a_connected_rx) = on_connected(); + + let a_agent = Arc::new( + Agent::new(AgentConfig { + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + ..Default::default() + }) + .await?, + ); + + a_agent.on_connection_state_change(a_notifier); + + let disconnected_duration = Duration::from_secs(1); + let failed_duration = Duration::from_secs(1); + let keepalive_interval = Duration::from_secs(0); + + let b_agent = Arc::new( + Agent::new(AgentConfig { + lite: true, + candidate_types: vec![CandidateType::Host], + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + disconnected_timeout: Some(disconnected_duration), + failed_timeout: Some(failed_duration), + keepalive_interval: Some(keepalive_interval), + check_interval: Duration::from_millis(500), + ..Default::default() + }) + .await?, + ); + + let (b_connected_tx, mut b_connected_rx) = mpsc::channel::<()>(1); + let (b_disconnected_tx, mut b_disconnected_rx) = mpsc::channel::<()>(1); + let (b_failed_tx, mut b_failed_rx) = mpsc::channel::<()>(1); + let b_connected_tx = Arc::new(Mutex::new(Some(b_connected_tx))); + let b_disconnected_tx = Arc::new(Mutex::new(Some(b_disconnected_tx))); + let b_failed_tx = Arc::new(Mutex::new(Some(b_failed_tx))); + + b_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { + let b_connected_tx_clone = Arc::clone(&b_connected_tx); + let b_disconnected_tx_clone = Arc::clone(&b_disconnected_tx); + let b_failed_tx_clone = Arc::clone(&b_failed_tx); + + Box::pin(async move { + if c == ConnectionState::Connected { + let mut tx = b_connected_tx_clone.lock().await; + tx.take(); + } else if c == ConnectionState::Disconnected { + let mut tx = b_disconnected_tx_clone.lock().await; + tx.take(); + } else if c == ConnectionState::Failed { + let mut tx = b_failed_tx_clone.lock().await; + tx.take(); + } + }) + })); + + connect_with_vnet(&b_agent, &a_agent).await?; + + let _ = a_connected_rx.recv().await; + let _ = b_connected_rx.recv().await; + a_agent.close().await?; + + let _ = b_disconnected_rx.recv().await; + let _ = b_failed_rx.recv().await; + + b_agent.close().await?; + + Ok(()) +} diff --git a/reserved/ice/src/agent/agent_transport.rs b/reserved/ice/src/agent/agent_transport.rs new file mode 100644 index 0000000..e2cdc56 --- /dev/null +++ b/reserved/ice/src/agent/agent_transport.rs @@ -0,0 +1,246 @@ +use std::io; +use std::sync::atomic::{AtomicBool, Ordering}; + +use arc_swap::ArcSwapOption; +use async_trait::async_trait; +use util::Conn; + +use super::*; +use crate::error::*; + +impl Agent { + /// Connects to the remote agent, acting as the controlling ice agent. + /// The method blocks until at least one ice candidate pair has successfully connected. + /// + /// The operation will be cancelled if `cancel_rx` either receives a message or its channel + /// closes. 
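+    ///
+    /// A rough usage sketch (illustrative, not part of the original docs): `agent`,
+    /// `remote_ufrag` and `remote_pwd` are placeholders, and the peer's credentials
+    /// are assumed to have been exchanged via signaling, as the test helper
+    /// `connect_with_vnet` in this diff does.
+    ///
+    /// ```ignore
+    /// let (_cancel_tx, cancel_rx) = tokio::sync::mpsc::channel(1);
+    /// let conn = agent.dial(cancel_rx, remote_ufrag, remote_pwd).await?;
+    /// // `conn` can now be used to send and receive application data.
+    /// ```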
+ pub async fn dial( + &self, + mut cancel_rx: mpsc::Receiver<()>, + remote_ufrag: String, + remote_pwd: String, + ) -> Result> { + let (on_connected_rx, agent_conn) = { + self.internal + .start_connectivity_checks(true, remote_ufrag, remote_pwd) + .await?; + + let mut on_connected_rx = self.internal.on_connected_rx.lock().await; + ( + on_connected_rx.take(), + Arc::clone(&self.internal.agent_conn), + ) + }; + + if let Some(mut on_connected_rx) = on_connected_rx { + // block until pair selected + tokio::select! { + _ = on_connected_rx.recv() => {}, + _ = cancel_rx.recv() => { + return Err(Error::ErrCanceledByCaller); + } + } + } + Ok(agent_conn) + } + + /// Connects to the remote agent, acting as the controlled ice agent. + /// The method blocks until at least one ice candidate pair has successfully connected. + /// + /// The operation will be cancelled if `cancel_rx` either receives a message or its channel + /// closes. + pub async fn accept( + &self, + mut cancel_rx: mpsc::Receiver<()>, + remote_ufrag: String, + remote_pwd: String, + ) -> Result> { + let (on_connected_rx, agent_conn) = { + self.internal + .start_connectivity_checks(false, remote_ufrag, remote_pwd) + .await?; + + let mut on_connected_rx = self.internal.on_connected_rx.lock().await; + ( + on_connected_rx.take(), + Arc::clone(&self.internal.agent_conn), + ) + }; + + if let Some(mut on_connected_rx) = on_connected_rx { + // block until pair selected + tokio::select! { + _ = on_connected_rx.recv() => {}, + _ = cancel_rx.recv() => { + return Err(Error::ErrCanceledByCaller); + } + } + } + + Ok(agent_conn) + } +} + +pub(crate) struct AgentConn { + pub(crate) selected_pair: ArcSwapOption, + pub(crate) checklist: Mutex>>, + + pub(crate) buffer: Buffer, + pub(crate) bytes_received: AtomicUsize, + pub(crate) bytes_sent: AtomicUsize, + pub(crate) done: AtomicBool, +} + +impl AgentConn { + pub(crate) fn new() -> Self { + Self { + selected_pair: ArcSwapOption::empty(), + checklist: Mutex::new(vec![]), + // Make sure the buffer doesn't grow indefinitely. + // NOTE: We actually won't get anywhere close to this limit. + // SRTP will constantly read from the endpoint and drop packets if it's full. + buffer: Buffer::new(0, MAX_BUFFER_SIZE), + bytes_received: AtomicUsize::new(0), + bytes_sent: AtomicUsize::new(0), + done: AtomicBool::new(false), + } + } + pub(crate) fn get_selected_pair(&self) -> Option> { + self.selected_pair.load().clone() + } + + pub(crate) async fn get_best_available_candidate_pair(&self) -> Option> { + let mut best: Option<&Arc> = None; + + let checklist = self.checklist.lock().await; + for p in &*checklist { + if p.state.load(Ordering::SeqCst) == CandidatePairState::Failed as u8 { + continue; + } + + if let Some(b) = &mut best { + if b.priority() < p.priority() { + *b = p; + } + } else { + best = Some(p); + } + } + + best.cloned() + } + + pub(crate) async fn get_best_valid_candidate_pair(&self) -> Option> { + let mut best: Option<&Arc> = None; + + let checklist = self.checklist.lock().await; + for p in &*checklist { + if p.state.load(Ordering::SeqCst) != CandidatePairState::Succeeded as u8 { + continue; + } + + if let Some(b) = &mut best { + if b.priority() < p.priority() { + *b = p; + } + } else { + best = Some(p); + } + } + + best.cloned() + } + + /// Returns the number of bytes sent. + pub fn bytes_sent(&self) -> usize { + self.bytes_sent.load(Ordering::SeqCst) + } + + /// Returns the number of bytes received. 
+ pub fn bytes_received(&self) -> usize { + self.bytes_received.load(Ordering::SeqCst) + } +} + +#[async_trait] +impl Conn for AgentConn { + async fn connect(&self, _addr: SocketAddr) -> std::result::Result<(), util::Error> { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + async fn recv(&self, buf: &mut [u8]) -> std::result::Result { + if self.done.load(Ordering::SeqCst) { + return Err(io::Error::new(io::ErrorKind::Other, "Conn is closed").into()); + } + + let n = match self.buffer.read(buf, None).await { + Ok(n) => n, + Err(err) => return Err(io::Error::new(io::ErrorKind::Other, err.to_string()).into()), + }; + self.bytes_received.fetch_add(n, Ordering::SeqCst); + + Ok(n) + } + + async fn recv_from( + &self, + buf: &mut [u8], + ) -> std::result::Result<(usize, SocketAddr), util::Error> { + if let Some(raddr) = self.remote_addr() { + let n = self.recv(buf).await?; + Ok((n, raddr)) + } else { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + } + + async fn send(&self, buf: &[u8]) -> std::result::Result { + if self.done.load(Ordering::SeqCst) { + return Err(io::Error::new(io::ErrorKind::Other, "Conn is closed").into()); + } + + if is_message(buf) { + return Err(util::Error::Other("ErrIceWriteStunMessage".into())); + } + + let result = if let Some(pair) = self.get_selected_pair() { + pair.write(buf).await + } else if let Some(pair) = self.get_best_available_candidate_pair().await { + pair.write(buf).await + } else { + Ok(0) + }; + + match result { + Ok(n) => { + self.bytes_sent.fetch_add(buf.len(), Ordering::SeqCst); + Ok(n) + } + Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string()).into()), + } + } + + async fn send_to( + &self, + _buf: &[u8], + _target: SocketAddr, + ) -> std::result::Result { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + fn local_addr(&self) -> std::result::Result { + if let Some(pair) = self.get_selected_pair() { + Ok(pair.local.addr()) + } else { + Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Addr Not Available").into()) + } + } + + fn remote_addr(&self) -> Option { + self.get_selected_pair().map(|pair| pair.remote.addr()) + } + + async fn close(&self) -> std::result::Result<(), util::Error> { + Ok(()) + } +} diff --git a/reserved/ice/src/agent/agent_transport_test.rs b/reserved/ice/src/agent/agent_transport_test.rs new file mode 100644 index 0000000..8d4a801 --- /dev/null +++ b/reserved/ice/src/agent/agent_transport_test.rs @@ -0,0 +1,133 @@ +use util::vnet::*; +use util::Conn; +use waitgroup::WaitGroup; + +use super::agent_vnet_test::*; +use super::*; +use crate::agent::agent_transport::AgentConn; + +pub(crate) async fn pipe( + default_config0: Option, + default_config1: Option, +) -> Result<(Arc, Arc, Arc, Arc)> { + let (a_notifier, mut a_connected) = on_connected(); + let (b_notifier, mut b_connected) = on_connected(); + + let mut cfg0 = if let Some(cfg) = default_config0 { + cfg + } else { + AgentConfig::default() + }; + cfg0.urls = vec![]; + cfg0.network_types = supported_network_types(); + + let a_agent = Arc::new(Agent::new(cfg0).await?); + a_agent.on_connection_state_change(a_notifier); + + let mut cfg1 = if let Some(cfg) = default_config1 { + cfg + } else { + AgentConfig::default() + }; + cfg1.urls = vec![]; + cfg1.network_types = supported_network_types(); + + let b_agent = Arc::new(Agent::new(cfg1).await?); + b_agent.on_connection_state_change(b_notifier); + + let (a_conn, b_conn) = connect_with_vnet(&a_agent, &b_agent).await?; + + // Ensure pair 
selected + // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair + let _ = a_connected.recv().await; + let _ = b_connected.recv().await; + + Ok((a_conn, b_conn, a_agent, b_agent)) +} + +#[tokio::test] +async fn test_remote_local_addr() -> Result<()> { + // Agent0 is behind 1:1 NAT + let nat_type0 = nat::NatType { + mode: nat::NatMode::Nat1To1, + ..Default::default() + }; + // Agent1 is behind 1:1 NAT + let nat_type1 = nat::NatType { + mode: nat::NatMode::Nat1To1, + ..Default::default() + }; + + let v = build_vnet(nat_type0, nat_type1).await?; + + let stun_server_url = Url { + scheme: SchemeType::Stun, + host: VNET_STUN_SERVER_IP.to_owned(), + port: VNET_STUN_SERVER_PORT, + proto: ProtoType::Udp, + ..Default::default() + }; + + //"Disconnected Returns nil" + { + let disconnected_conn = AgentConn::new(); + let result = disconnected_conn.local_addr(); + assert!(result.is_err(), "Disconnected Returns nil"); + } + + //"Remote/Local Pair Match between Agents" + { + let (ca, cb) = pipe_with_vnet( + &v, + AgentTestConfig { + urls: vec![stun_server_url.clone()], + ..Default::default() + }, + AgentTestConfig { + urls: vec![stun_server_url], + ..Default::default() + }, + ) + .await?; + + let a_laddr = ca.local_addr()?; + let b_laddr = cb.local_addr()?; + + // Assert addresses + assert_eq!(a_laddr.ip().to_string(), VNET_LOCAL_IPA.to_string()); + assert_eq!(b_laddr.ip().to_string(), VNET_LOCAL_IPB.to_string()); + + // Close + //ca.close().await?; + //cb.close().await?; + } + + v.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_conn_stats() -> Result<()> { + let (ca, cb, _, _) = pipe(None, None).await?; + let na = ca.send(&[0u8; 10]).await?; + + let wg = WaitGroup::new(); + + let w = wg.worker(); + tokio::spawn(async move { + let _d = w; + + let mut buf = vec![0u8; 10]; + let nb = cb.recv(&mut buf).await?; + assert_eq!(nb, 10, "bytes received don't match"); + + Result::<()>::Ok(()) + }); + + wg.wait().await; + + assert_eq!(na, 10, "bytes sent don't match"); + + Ok(()) +} diff --git a/reserved/ice/src/agent/agent_vnet_test.rs b/reserved/ice/src/agent/agent_vnet_test.rs new file mode 100644 index 0000000..0a58712 --- /dev/null +++ b/reserved/ice/src/agent/agent_vnet_test.rs @@ -0,0 +1,1016 @@ +use std::net::{IpAddr, Ipv4Addr}; +use std::result::Result; +use std::str::FromStr; +use std::sync::atomic::AtomicU64; + +use async_trait::async_trait; +use util::vnet::chunk::Chunk; +use util::vnet::router::Nic; +use util::vnet::*; +use util::Conn; +use waitgroup::WaitGroup; + +use super::*; +use crate::candidate::candidate_base::unmarshal_candidate; + +pub(crate) struct MockConn; + +#[async_trait] +impl Conn for MockConn { + async fn connect(&self, _addr: SocketAddr) -> Result<(), util::Error> { + Ok(()) + } + async fn recv(&self, _buf: &mut [u8]) -> Result { + Ok(0) + } + async fn recv_from(&self, _buf: &mut [u8]) -> Result<(usize, SocketAddr), util::Error> { + Ok((0, SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0))) + } + async fn send(&self, _buf: &[u8]) -> Result { + Ok(0) + } + async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> Result { + Ok(0) + } + fn local_addr(&self) -> Result { + Ok(SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0)) + } + fn remote_addr(&self) -> Option { + None + } + async fn close(&self) -> Result<(), util::Error> { + Ok(()) + } +} + +pub(crate) struct VNet { + pub(crate) wan: Arc>, + pub(crate) net0: Arc, + pub(crate) net1: Arc, + pub(crate) server: turn::server::Server, +} + +impl VNet { + pub(crate) async fn 
close(&self) -> Result<(), Error> { + self.server.close().await?; + let mut w = self.wan.lock().await; + w.stop().await?; + Ok(()) + } +} + +pub(crate) const VNET_GLOBAL_IPA: &str = "27.1.1.1"; +pub(crate) const VNET_LOCAL_IPA: &str = "192.168.0.1"; +pub(crate) const VNET_LOCAL_SUBNET_MASK_A: &str = "24"; +pub(crate) const VNET_GLOBAL_IPB: &str = "28.1.1.1"; +pub(crate) const VNET_LOCAL_IPB: &str = "10.2.0.1"; +pub(crate) const VNET_LOCAL_SUBNET_MASK_B: &str = "24"; +pub(crate) const VNET_STUN_SERVER_IP: &str = "1.2.3.4"; +pub(crate) const VNET_STUN_SERVER_PORT: u16 = 3478; + +pub(crate) async fn build_simple_vnet( + _nat_type0: nat::NatType, + _nat_type1: nat::NatType, +) -> Result { + // WAN + let wan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + cidr: "0.0.0.0/0".to_owned(), + ..Default::default() + })?)); + + let wnet = Arc::new(net::Net::new(Some(net::NetConfig { + static_ip: VNET_STUN_SERVER_IP.to_owned(), // will be assigned to eth0 + ..Default::default() + }))); + + connect_net2router(&wnet, &wan).await?; + + // LAN + let lan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + cidr: format!("{VNET_LOCAL_IPA}/{VNET_LOCAL_SUBNET_MASK_A}"), + ..Default::default() + })?)); + + let net0 = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec!["192.168.0.1".to_owned()], + ..Default::default() + }))); + let net1 = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec!["192.168.0.2".to_owned()], + ..Default::default() + }))); + + connect_net2router(&net0, &lan).await?; + connect_net2router(&net1, &lan).await?; + connect_router2router(&lan, &wan).await?; + + // start routers... + start_router(&wan).await?; + + let server = add_vnet_stun(wnet).await?; + + Ok(VNet { + wan, + net0, + net1, + server, + }) +} + +pub(crate) async fn build_vnet( + nat_type0: nat::NatType, + nat_type1: nat::NatType, +) -> Result { + // WAN + let wan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + cidr: "0.0.0.0/0".to_owned(), + ..Default::default() + })?)); + + let wnet = Arc::new(net::Net::new(Some(net::NetConfig { + static_ip: VNET_STUN_SERVER_IP.to_owned(), // will be assigned to eth0 + ..Default::default() + }))); + + connect_net2router(&wnet, &wan).await?; + + // LAN 0 + let lan0 = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + static_ips: if nat_type0.mode == nat::NatMode::Nat1To1 { + vec![format!("{VNET_GLOBAL_IPA}/{VNET_LOCAL_IPA}")] + } else { + vec![VNET_GLOBAL_IPA.to_owned()] + }, + cidr: format!("{VNET_LOCAL_IPA}/{VNET_LOCAL_SUBNET_MASK_A}"), + nat_type: Some(nat_type0), + ..Default::default() + })?)); + + let net0 = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec![VNET_LOCAL_IPA.to_owned()], + ..Default::default() + }))); + + connect_net2router(&net0, &lan0).await?; + connect_router2router(&lan0, &wan).await?; + + // LAN 1 + let lan1 = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { + static_ips: if nat_type1.mode == nat::NatMode::Nat1To1 { + vec![format!("{VNET_GLOBAL_IPB}/{VNET_LOCAL_IPB}")] + } else { + vec![VNET_GLOBAL_IPB.to_owned()] + }, + cidr: format!("{VNET_LOCAL_IPB}/{VNET_LOCAL_SUBNET_MASK_B}"), + nat_type: Some(nat_type1), + ..Default::default() + })?)); + + let net1 = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec![VNET_LOCAL_IPB.to_owned()], + ..Default::default() + }))); + + connect_net2router(&net1, &lan1).await?; + connect_router2router(&lan1, &wan).await?; + + // start routers... 
+ start_router(&wan).await?; + + let server = add_vnet_stun(wnet).await?; + + Ok(VNet { + wan, + net0, + net1, + server, + }) +} + +pub(crate) struct TestAuthHandler { + pub(crate) cred_map: HashMap>, +} + +impl TestAuthHandler { + pub(crate) fn new() -> Self { + let mut cred_map = HashMap::new(); + cred_map.insert( + "user".to_owned(), + turn::auth::generate_auth_key("user", "webrtc.rs", "pass"), + ); + + TestAuthHandler { cred_map } + } +} + +impl turn::auth::AuthHandler for TestAuthHandler { + fn auth_handle( + &self, + username: &str, + _realm: &str, + _src_addr: SocketAddr, + ) -> Result, turn::Error> { + if let Some(pw) = self.cred_map.get(username) { + Ok(pw.to_vec()) + } else { + Err(turn::Error::Other("fake error".to_owned())) + } + } +} + +pub(crate) async fn add_vnet_stun(wan_net: Arc) -> Result { + // Run TURN(STUN) server + let conn = wan_net + .bind(SocketAddr::from_str(&format!( + "{VNET_STUN_SERVER_IP}:{VNET_STUN_SERVER_PORT}" + ))?) + .await?; + + let server = turn::server::Server::new(turn::server::config::ServerConfig { + conn_configs: vec![turn::server::config::ConnConfig { + conn, + relay_addr_generator: Box::new( + turn::relay::relay_static::RelayAddressGeneratorStatic { + relay_address: IpAddr::from_str(VNET_STUN_SERVER_IP)?, + address: "0.0.0.0".to_owned(), + net: wan_net, + }, + ), + }], + realm: "webrtc.rs".to_owned(), + auth_handler: Arc::new(TestAuthHandler::new()), + channel_bind_timeout: Duration::from_secs(0), + //alloc_close_notify: None, + }) + .await?; + + Ok(server) +} + +pub(crate) async fn connect_with_vnet( + a_agent: &Arc, + b_agent: &Arc, +) -> Result<(Arc, Arc), Error> { + // Manual signaling + let (a_ufrag, a_pwd) = a_agent.get_local_user_credentials().await; + let (b_ufrag, b_pwd) = b_agent.get_local_user_credentials().await; + + gather_and_exchange_candidates(a_agent, b_agent).await?; + + let (accepted_tx, mut accepted_rx) = mpsc::channel(1); + let (_a_cancel_tx, a_cancel_rx) = mpsc::channel(1); + + let agent_a = Arc::clone(a_agent); + tokio::spawn(async move { + let a_conn = agent_a.accept(a_cancel_rx, b_ufrag, b_pwd).await?; + + let _ = accepted_tx.send(a_conn).await; + + Result::<(), Error>::Ok(()) + }); + + let (_b_cancel_tx, b_cancel_rx) = mpsc::channel(1); + let b_conn = b_agent.dial(b_cancel_rx, a_ufrag, a_pwd).await?; + + // Ensure accepted + if let Some(a_conn) = accepted_rx.recv().await { + Ok((a_conn, b_conn)) + } else { + Err(Error::Other("no a_conn".to_owned())) + } +} + +#[derive(Default)] +pub(crate) struct AgentTestConfig { + pub(crate) urls: Vec, + pub(crate) nat_1to1_ip_candidate_type: CandidateType, +} + +pub(crate) async fn pipe_with_vnet( + v: &VNet, + a0test_config: AgentTestConfig, + a1test_config: AgentTestConfig, +) -> Result<(Arc, Arc), Error> { + let (a_notifier, mut a_connected) = on_connected(); + let (b_notifier, mut b_connected) = on_connected(); + + let nat_1to1_ips = if a0test_config.nat_1to1_ip_candidate_type != CandidateType::Unspecified { + vec![VNET_GLOBAL_IPA.to_owned()] + } else { + vec![] + }; + + let cfg0 = AgentConfig { + urls: a0test_config.urls, + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + nat_1to1_ips, + nat_1to1_ip_candidate_type: a0test_config.nat_1to1_ip_candidate_type, + net: Some(Arc::clone(&v.net0)), + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + a_agent.on_connection_state_change(a_notifier); + + let nat_1to1_ips = if a1test_config.nat_1to1_ip_candidate_type != CandidateType::Unspecified { + 
vec![VNET_GLOBAL_IPB.to_owned()] + } else { + vec![] + }; + let cfg1 = AgentConfig { + urls: a1test_config.urls, + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + nat_1to1_ips, + nat_1to1_ip_candidate_type: a1test_config.nat_1to1_ip_candidate_type, + net: Some(Arc::clone(&v.net1)), + ..Default::default() + }; + + let b_agent = Arc::new(Agent::new(cfg1).await?); + b_agent.on_connection_state_change(b_notifier); + + let (a_conn, b_conn) = connect_with_vnet(&a_agent, &b_agent).await?; + + // Ensure pair selected + // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair + let _ = a_connected.recv().await; + let _ = b_connected.recv().await; + + Ok((a_conn, b_conn)) +} + +pub(crate) fn on_connected() -> (OnConnectionStateChangeHdlrFn, mpsc::Receiver<()>) { + let (done_tx, done_rx) = mpsc::channel::<()>(1); + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + let hdlr_fn: OnConnectionStateChangeHdlrFn = Box::new(move |state: ConnectionState| { + let done_tx_clone = Arc::clone(&done_tx); + Box::pin(async move { + if state == ConnectionState::Connected { + let mut tx = done_tx_clone.lock().await; + tx.take(); + } + }) + }); + (hdlr_fn, done_rx) +} + +pub(crate) async fn gather_and_exchange_candidates( + a_agent: &Arc, + b_agent: &Arc, +) -> Result<(), Error> { + let wg = WaitGroup::new(); + + let w1 = Arc::new(Mutex::new(Some(wg.worker()))); + a_agent.on_candidate(Box::new( + move |candidate: Option>| { + let w3 = Arc::clone(&w1); + Box::pin(async move { + if candidate.is_none() { + let mut w = w3.lock().await; + w.take(); + } + }) + }, + )); + a_agent.gather_candidates()?; + + let w2 = Arc::new(Mutex::new(Some(wg.worker()))); + b_agent.on_candidate(Box::new( + move |candidate: Option>| { + let w3 = Arc::clone(&w2); + Box::pin(async move { + if candidate.is_none() { + let mut w = w3.lock().await; + w.take(); + } + }) + }, + )); + b_agent.gather_candidates()?; + + wg.wait().await; + + let candidates = a_agent.get_local_candidates().await?; + for c in candidates { + let c2: Arc = + Arc::new(unmarshal_candidate(c.marshal().as_str())?); + b_agent.add_remote_candidate(&c2)?; + } + + let candidates = b_agent.get_local_candidates().await?; + for c in candidates { + let c2: Arc = + Arc::new(unmarshal_candidate(c.marshal().as_str())?); + a_agent.add_remote_candidate(&c2)?; + } + + Ok(()) +} + +pub(crate) async fn start_router(router: &Arc>) -> Result<(), Error> { + let mut w = router.lock().await; + Ok(w.start().await?) 
+} + +pub(crate) async fn connect_net2router( + net: &Arc, + router: &Arc>, +) -> Result<(), Error> { + let nic = net.get_nic()?; + + { + let mut w = router.lock().await; + w.add_net(Arc::clone(&nic)).await?; + } + { + let n = nic.lock().await; + n.set_router(Arc::clone(router)).await?; + } + + Ok(()) +} + +pub(crate) async fn connect_router2router( + child: &Arc>, + parent: &Arc>, +) -> Result<(), Error> { + { + let mut w = parent.lock().await; + w.add_router(Arc::clone(child)).await?; + } + + { + let l = child.lock().await; + l.set_router(Arc::clone(parent)).await?; + } + + Ok(()) +} + +#[tokio::test] +async fn test_connectivity_simple_vnet_full_cone_nats_on_both_ends() -> Result<(), Error> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init();*/ + + let stun_server_url = Url { + scheme: SchemeType::Stun, + host: VNET_STUN_SERVER_IP.to_owned(), + port: VNET_STUN_SERVER_PORT, + proto: ProtoType::Udp, + ..Default::default() + }; + + // buildVNet with a Full-cone NATs both LANs + let nat_type = nat::NatType { + mapping_behavior: nat::EndpointDependencyType::EndpointIndependent, + filtering_behavior: nat::EndpointDependencyType::EndpointIndependent, + ..Default::default() + }; + + let v = build_simple_vnet(nat_type, nat_type).await?; + + log::debug!("Connecting..."); + let a0test_config = AgentTestConfig { + urls: vec![stun_server_url.clone()], + ..Default::default() + }; + let a1test_config = AgentTestConfig { + urls: vec![stun_server_url.clone()], + ..Default::default() + }; + let (_ca, _cb) = pipe_with_vnet(&v, a0test_config, a1test_config).await?; + + tokio::time::sleep(Duration::from_secs(1)).await; + + log::debug!("Closing..."); + v.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_connectivity_vnet_full_cone_nats_on_both_ends() -> Result<(), Error> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init();*/ + + let stun_server_url = Url { + scheme: SchemeType::Stun, + host: VNET_STUN_SERVER_IP.to_owned(), + port: VNET_STUN_SERVER_PORT, + proto: ProtoType::Udp, + ..Default::default() + }; + + let _turn_server_url = Url { + scheme: SchemeType::Turn, + host: VNET_STUN_SERVER_IP.to_owned(), + port: VNET_STUN_SERVER_PORT, + username: "user".to_owned(), + password: "pass".to_owned(), + proto: ProtoType::Udp, + }; + + // buildVNet with a Full-cone NATs both LANs + let nat_type = nat::NatType { + mapping_behavior: nat::EndpointDependencyType::EndpointIndependent, + filtering_behavior: nat::EndpointDependencyType::EndpointIndependent, + ..Default::default() + }; + + let v = build_vnet(nat_type, nat_type).await?; + + log::debug!("Connecting..."); + let a0test_config = AgentTestConfig { + urls: vec![stun_server_url.clone()], + ..Default::default() + }; + let a1test_config = AgentTestConfig { + urls: vec![stun_server_url.clone()], + ..Default::default() + }; + let (_ca, _cb) = pipe_with_vnet(&v, a0test_config, a1test_config).await?; + + tokio::time::sleep(Duration::from_secs(1)).await; + + log::debug!("Closing..."); + v.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn 
test_connectivity_vnet_symmetric_nats_on_both_ends() -> Result<(), Error> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init();*/ + + let stun_server_url = Url { + scheme: SchemeType::Stun, + host: VNET_STUN_SERVER_IP.to_owned(), + port: VNET_STUN_SERVER_PORT, + proto: ProtoType::Udp, + ..Default::default() + }; + + let turn_server_url = Url { + scheme: SchemeType::Turn, + host: VNET_STUN_SERVER_IP.to_owned(), + port: VNET_STUN_SERVER_PORT, + username: "user".to_owned(), + password: "pass".to_owned(), + proto: ProtoType::Udp, + }; + + // buildVNet with a Symmetric NATs for both LANs + let nat_type = nat::NatType { + mapping_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, + filtering_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, + ..Default::default() + }; + + let v = build_vnet(nat_type, nat_type).await?; + + log::debug!("Connecting..."); + let a0test_config = AgentTestConfig { + urls: vec![stun_server_url.clone(), turn_server_url.clone()], + ..Default::default() + }; + let a1test_config = AgentTestConfig { + urls: vec![stun_server_url.clone()], + ..Default::default() + }; + let (_ca, _cb) = pipe_with_vnet(&v, a0test_config, a1test_config).await?; + + tokio::time::sleep(Duration::from_secs(1)).await; + + log::debug!("Closing..."); + v.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_connectivity_vnet_1to1_nat_with_host_candidate_vs_symmetric_nats() -> Result<(), Error> +{ + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init();*/ + + // Agent0 is behind 1:1 NAT + let nat_type0 = nat::NatType { + mode: nat::NatMode::Nat1To1, + ..Default::default() + }; + // Agent1 is behind a symmetric NAT + let nat_type1 = nat::NatType { + mapping_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, + filtering_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, + ..Default::default() + }; + log::debug!("natType0: {:?}", nat_type0); + log::debug!("natType1: {:?}", nat_type1); + + let v = build_vnet(nat_type0, nat_type1).await?; + + log::debug!("Connecting..."); + let a0test_config = AgentTestConfig { + urls: vec![], + nat_1to1_ip_candidate_type: CandidateType::Host, // Use 1:1 NAT IP as a host candidate + }; + let a1test_config = AgentTestConfig { + urls: vec![], + ..Default::default() + }; + let (_ca, _cb) = pipe_with_vnet(&v, a0test_config, a1test_config).await?; + + tokio::time::sleep(Duration::from_secs(1)).await; + + log::debug!("Closing..."); + v.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_connectivity_vnet_1to1_nat_with_srflx_candidate_vs_symmetric_nats( +) -> Result<(), Error> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init();*/ + + // Agent0 is behind 1:1 NAT + let nat_type0 = nat::NatType { + mode: nat::NatMode::Nat1To1, + ..Default::default() 
+ }; + // Agent1 is behind a symmetric NAT + let nat_type1 = nat::NatType { + mapping_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, + filtering_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, + ..Default::default() + }; + log::debug!("natType0: {:?}", nat_type0); + log::debug!("natType1: {:?}", nat_type1); + + let v = build_vnet(nat_type0, nat_type1).await?; + + log::debug!("Connecting..."); + let a0test_config = AgentTestConfig { + urls: vec![], + nat_1to1_ip_candidate_type: CandidateType::ServerReflexive, // Use 1:1 NAT IP as a srflx candidate + }; + let a1test_config = AgentTestConfig { + urls: vec![], + ..Default::default() + }; + let (_ca, _cb) = pipe_with_vnet(&v, a0test_config, a1test_config).await?; + + tokio::time::sleep(Duration::from_secs(1)).await; + + log::debug!("Closing..."); + v.close().await?; + + Ok(()) +} + +async fn block_until_state_seen( + expected_state: ConnectionState, + state_queue: &mut mpsc::Receiver, +) { + while let Some(s) = state_queue.recv().await { + if s == expected_state { + return; + } + } +} + +// test_disconnected_to_connected asserts that an agent can go to disconnected, and then return to connected successfully +#[tokio::test] +async fn test_disconnected_to_connected() -> Result<(), Error> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init();*/ + + // Create a network with two interfaces + let wan = router::Router::new(router::RouterConfig { + cidr: "0.0.0.0/0".to_owned(), + ..Default::default() + })?; + + let drop_all_data = Arc::new(AtomicU64::new(0)); + let drop_all_data2 = Arc::clone(&drop_all_data); + wan.add_chunk_filter(Box::new(move |_c: &(dyn Chunk + Send + Sync)| -> bool { + drop_all_data2.load(Ordering::SeqCst) != 1 + })) + .await; + let wan = Arc::new(Mutex::new(wan)); + + let net0 = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec!["192.168.0.1".to_owned()], + ..Default::default() + }))); + let net1 = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec!["192.168.0.2".to_owned()], + ..Default::default() + }))); + + connect_net2router(&net0, &wan).await?; + connect_net2router(&net1, &wan).await?; + start_router(&wan).await?; + + let disconnected_timeout = Duration::from_secs(1); + let keepalive_interval = Duration::from_millis(20); + + // Create two agents and connect them + let controlling_agent = Arc::new( + Agent::new(AgentConfig { + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + net: Some(Arc::clone(&net0)), + disconnected_timeout: Some(disconnected_timeout), + keepalive_interval: Some(keepalive_interval), + check_interval: keepalive_interval, + ..Default::default() + }) + .await?, + ); + + let controlled_agent = Arc::new( + Agent::new(AgentConfig { + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + net: Some(Arc::clone(&net1)), + disconnected_timeout: Some(disconnected_timeout), + keepalive_interval: Some(keepalive_interval), + check_interval: keepalive_interval, + ..Default::default() + }) + .await?, + ); + + let (controlling_state_changes_tx, mut controlling_state_changes_rx) = + mpsc::channel::(100); + let controlling_state_changes_tx = Arc::new(controlling_state_changes_tx); + 
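+    // Note: the handlers below simply forward every state transition into the
+    // channel; `block_until_state_seen` then drains it until the expected state
+    // shows up.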
controlling_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { + let controlling_state_changes_tx_clone = Arc::clone(&controlling_state_changes_tx); + Box::pin(async move { + let _ = controlling_state_changes_tx_clone.try_send(c); + }) + })); + + let (controlled_state_changes_tx, mut controlled_state_changes_rx) = + mpsc::channel::(100); + let controlled_state_changes_tx = Arc::new(controlled_state_changes_tx); + controlled_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { + let controlled_state_changes_tx_clone = Arc::clone(&controlled_state_changes_tx); + Box::pin(async move { + let _ = controlled_state_changes_tx_clone.try_send(c); + }) + })); + + connect_with_vnet(&controlling_agent, &controlled_agent).await?; + + // Assert we have gone to connected + block_until_state_seen( + ConnectionState::Connected, + &mut controlling_state_changes_rx, + ) + .await; + block_until_state_seen(ConnectionState::Connected, &mut controlled_state_changes_rx).await; + + // Drop all packets, and block until we have gone to disconnected + drop_all_data.store(1, Ordering::SeqCst); + block_until_state_seen( + ConnectionState::Disconnected, + &mut controlling_state_changes_rx, + ) + .await; + block_until_state_seen( + ConnectionState::Disconnected, + &mut controlled_state_changes_rx, + ) + .await; + + // Allow all packets through again, block until we have gone to connected + drop_all_data.store(0, Ordering::SeqCst); + block_until_state_seen( + ConnectionState::Connected, + &mut controlling_state_changes_rx, + ) + .await; + block_until_state_seen(ConnectionState::Connected, &mut controlled_state_changes_rx).await; + + { + let mut w = wan.lock().await; + w.stop().await?; + } + + controlling_agent.close().await?; + controlled_agent.close().await?; + + Ok(()) +} + +//use std::io::Write; + +// Agent.Write should use the best valid pair if a selected pair is not yet available +#[tokio::test] +async fn test_write_use_valid_pair() -> Result<(), Error> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init();*/ + + // Create a network with two interfaces + let wan = router::Router::new(router::RouterConfig { + cidr: "0.0.0.0/0".to_owned(), + ..Default::default() + })?; + + wan.add_chunk_filter(Box::new(move |c: &(dyn Chunk + Send + Sync)| -> bool { + let raw = c.user_data(); + if stun::message::is_message(&raw) { + let mut m = stun::message::Message { + raw, + ..Default::default() + }; + let result = m.decode(); + if result.is_err() | m.contains(stun::attributes::ATTR_USE_CANDIDATE) { + return false; + } + } + + true + })) + .await; + let wan = Arc::new(Mutex::new(wan)); + + let net0 = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec!["192.168.0.1".to_owned()], + ..Default::default() + }))); + let net1 = Arc::new(net::Net::new(Some(net::NetConfig { + static_ips: vec!["192.168.0.2".to_owned()], + ..Default::default() + }))); + + connect_net2router(&net0, &wan).await?; + connect_net2router(&net1, &wan).await?; + start_router(&wan).await?; + + // Create two agents and connect them + let controlling_agent = Arc::new( + Agent::new(AgentConfig { + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + net: Some(Arc::clone(&net0)), + ..Default::default() + }) + .await?, + ); + + let 
controlled_agent = Arc::new( + Agent::new(AgentConfig { + network_types: supported_network_types(), + multicast_dns_mode: MulticastDnsMode::Disabled, + net: Some(Arc::clone(&net1)), + ..Default::default() + }) + .await?, + ); + + gather_and_exchange_candidates(&controlling_agent, &controlled_agent).await?; + + let (controlling_ufrag, controlling_pwd) = controlling_agent.get_local_user_credentials().await; + let (controlled_ufrag, controlled_pwd) = controlled_agent.get_local_user_credentials().await; + + let controlling_agent_tx = Arc::clone(&controlling_agent); + tokio::spawn(async move { + let test_message = "Test Message"; + let controlling_agent_conn = { + controlling_agent_tx + .internal + .start_connectivity_checks(true, controlled_ufrag, controlled_pwd) + .await?; + Arc::clone(&controlling_agent_tx.internal.agent_conn) as Arc + }; + + log::debug!("controlling_agent start_connectivity_checks done..."); + loop { + let result = controlling_agent_conn.send(test_message.as_bytes()).await; + if result.is_err() { + break; + } + + tokio::time::sleep(Duration::from_millis(20)).await; + } + + Result::<(), Error>::Ok(()) + }); + + let controlled_agent_conn = { + controlled_agent + .internal + .start_connectivity_checks(false, controlling_ufrag, controlling_pwd) + .await?; + Arc::clone(&controlled_agent.internal.agent_conn) as Arc + }; + + log::debug!("controlled_agent start_connectivity_checks done..."); + + let test_message = "Test Message"; + let mut read_buf = vec![0u8; test_message.as_bytes().len()]; + controlled_agent_conn.recv(&mut read_buf).await?; + + assert_eq!(read_buf, test_message.as_bytes(), "should match"); + + { + let mut w = wan.lock().await; + w.stop().await?; + } + + controlling_agent.close().await?; + controlled_agent.close().await?; + + Ok(()) +} diff --git a/reserved/ice/src/agent/mod.rs b/reserved/ice/src/agent/mod.rs new file mode 100644 index 0000000..d61b392 --- /dev/null +++ b/reserved/ice/src/agent/mod.rs @@ -0,0 +1,516 @@ +#[cfg(test)] +mod agent_gather_test; +#[cfg(test)] +mod agent_test; +#[cfg(test)] +mod agent_transport_test; +#[cfg(test)] +pub(crate) mod agent_vnet_test; + +pub mod agent_config; +pub mod agent_gather; +pub(crate) mod agent_internal; +pub mod agent_selector; +pub mod agent_stats; +pub mod agent_transport; + +use std::collections::HashMap; +use std::future::Future; +use std::net::{Ipv4Addr, SocketAddr}; +use std::pin::Pin; +use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::SystemTime; + +use agent_config::*; +use agent_internal::*; +use agent_stats::*; +use mdns::conn::*; +use stun::agent::*; +use stun::attributes::*; +use stun::fingerprint::*; +use stun::integrity::*; +use stun::message::*; +use stun::xoraddr::*; +use tokio::sync::{broadcast, mpsc, Mutex}; +use tokio::time::{Duration, Instant}; +use util::vnet::net::*; +use util::Buffer; + +use crate::agent::agent_gather::GatherCandidatesInternalParams; +use crate::candidate::*; +use crate::error::*; +use crate::external_ip_mapper::*; +use crate::mdns::*; +use crate::network_type::*; +use crate::rand::*; +use crate::state::*; +use crate::tcp_type::TcpType; +use crate::udp_mux::UDPMux; +use crate::udp_network::UDPNetwork; +use crate::url::*; + +#[derive(Debug, Clone)] +pub(crate) struct BindingRequest { + pub(crate) timestamp: Instant, + pub(crate) transaction_id: TransactionId, + pub(crate) destination: SocketAddr, + pub(crate) is_use_candidate: bool, +} + +impl Default for BindingRequest { + fn default() -> Self { + Self { + timestamp: Instant::now(), 
+ transaction_id: TransactionId::default(), + destination: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0), + is_use_candidate: false, + } + } +} + +pub type OnConnectionStateChangeHdlrFn = Box< + dyn (FnMut(ConnectionState) -> Pin + Send + 'static>>) + + Send + + Sync, +>; +pub type OnSelectedCandidatePairChangeHdlrFn = Box< + dyn (FnMut( + &Arc, + &Arc, + ) -> Pin + Send + 'static>>) + + Send + + Sync, +>; +pub type OnCandidateHdlrFn = Box< + dyn (FnMut( + Option>, + ) -> Pin + Send + 'static>>) + + Send + + Sync, +>; +pub type GatherCandidateCancelFn = Box; + +struct ChanReceivers { + chan_state_rx: mpsc::Receiver, + chan_candidate_rx: mpsc::Receiver>>, + chan_candidate_pair_rx: mpsc::Receiver<()>, +} + +/// Represents the ICE agent. +pub struct Agent { + pub(crate) internal: Arc, + + pub(crate) udp_network: UDPNetwork, + pub(crate) interface_filter: Arc>, + pub(crate) ip_filter: Arc>, + pub(crate) mdns_mode: MulticastDnsMode, + pub(crate) mdns_name: String, + pub(crate) mdns_conn: Option>, + pub(crate) net: Arc, + + // 1:1 D-NAT IP address mapping + pub(crate) ext_ip_mapper: Arc>, + pub(crate) gathering_state: Arc, //GatheringState, + pub(crate) candidate_types: Vec, + pub(crate) urls: Vec, + pub(crate) network_types: Vec, + + pub(crate) gather_candidate_cancel: Option, +} + +impl Agent { + /// Creates a new Agent. + pub async fn new(config: AgentConfig) -> Result { + let mut mdns_name = config.multicast_dns_host_name.clone(); + if mdns_name.is_empty() { + mdns_name = generate_multicast_dns_name(); + } + + if !mdns_name.ends_with(".local") || mdns_name.split('.').count() != 2 { + return Err(Error::ErrInvalidMulticastDnshostName); + } + + let mdns_mode = config.multicast_dns_mode; + + let mdns_conn = + match create_multicast_dns(mdns_mode, &mdns_name, &config.multicast_dns_dest_addr) { + Ok(c) => c, + Err(err) => { + // Opportunistic mDNS: If we can't open the connection, that's ok: we + // can continue without it. 
+ log::warn!("Failed to initialize mDNS {}: {}", mdns_name, err); + None + } + }; + + let (mut ai, chan_receivers) = AgentInternal::new(&config); + let (chan_state_rx, chan_candidate_rx, chan_candidate_pair_rx) = ( + chan_receivers.chan_state_rx, + chan_receivers.chan_candidate_rx, + chan_receivers.chan_candidate_pair_rx, + ); + + config.init_with_defaults(&mut ai); + + let candidate_types = if config.candidate_types.is_empty() { + default_candidate_types() + } else { + config.candidate_types.clone() + }; + + if ai.lite.load(Ordering::SeqCst) + && (candidate_types.len() != 1 || candidate_types[0] != CandidateType::Host) + { + Self::close_multicast_conn(&mdns_conn).await; + return Err(Error::ErrLiteUsingNonHostCandidates); + } + + if !config.urls.is_empty() + && !contains_candidate_type(CandidateType::ServerReflexive, &candidate_types) + && !contains_candidate_type(CandidateType::Relay, &candidate_types) + { + Self::close_multicast_conn(&mdns_conn).await; + return Err(Error::ErrUselessUrlsProvided); + } + + let ext_ip_mapper = match config.init_ext_ip_mapping(mdns_mode, &candidate_types) { + Ok(ext_ip_mapper) => ext_ip_mapper, + Err(err) => { + Self::close_multicast_conn(&mdns_conn).await; + return Err(err); + } + }; + + let net = if let Some(net) = config.net { + if net.is_virtual() { + log::warn!("vnet is enabled"); + if mdns_mode != MulticastDnsMode::Disabled { + log::warn!("vnet does not support mDNS yet"); + } + } + + net + } else { + Arc::new(Net::new(None)) + }; + + let agent = Self { + udp_network: config.udp_network, + internal: Arc::new(ai), + interface_filter: Arc::clone(&config.interface_filter), + ip_filter: Arc::clone(&config.ip_filter), + mdns_mode, + mdns_name, + mdns_conn, + net, + ext_ip_mapper: Arc::new(ext_ip_mapper), + gathering_state: Arc::new(AtomicU8::new(0)), //GatheringState::New, + candidate_types, + urls: config.urls.clone(), + network_types: config.network_types.clone(), + + gather_candidate_cancel: None, //TODO: add cancel + }; + + agent.internal.start_on_connection_state_change_routine( + chan_state_rx, + chan_candidate_rx, + chan_candidate_pair_rx, + ); + + // Restart is also used to initialize the agent for the first time + if let Err(err) = agent.restart(config.local_ufrag, config.local_pwd).await { + Self::close_multicast_conn(&agent.mdns_conn).await; + let _ = agent.close().await; + return Err(err); + } + + Ok(agent) + } + + pub fn get_bytes_received(&self) -> usize { + self.internal.agent_conn.bytes_received() + } + + pub fn get_bytes_sent(&self) -> usize { + self.internal.agent_conn.bytes_sent() + } + + /// Sets a handler that is fired when the connection state changes. + pub fn on_connection_state_change(&self, f: OnConnectionStateChangeHdlrFn) { + self.internal + .on_connection_state_change_hdlr + .store(Some(Arc::new(Mutex::new(f)))) + } + + /// Sets a handler that is fired when the final candidate pair is selected. + pub fn on_selected_candidate_pair_change(&self, f: OnSelectedCandidatePairChangeHdlrFn) { + self.internal + .on_selected_candidate_pair_change_hdlr + .store(Some(Arc::new(Mutex::new(f)))) + } + + /// Sets a handler that is fired when new candidates gathered. When the gathering process + /// complete the last candidate is nil. + pub fn on_candidate(&self, f: OnCandidateHdlrFn) { + self.internal + .on_candidate_hdlr + .store(Some(Arc::new(Mutex::new(f)))); + } + + /// Adds a new remote candidate. 
+ pub fn add_remote_candidate(&self, c: &Arc) -> Result<()> { + // cannot check for network yet because it might not be applied + // when mDNS hostame is used. + if c.tcp_type() == TcpType::Active { + // TCP Candidates with tcptype active will probe server passive ones, so + // no need to do anything with them. + log::info!("Ignoring remote candidate with tcpType active: {}", c); + return Ok(()); + } + + // If we have a mDNS Candidate lets fully resolve it before adding it locally + if c.candidate_type() == CandidateType::Host && c.address().ends_with(".local") { + if self.mdns_mode == MulticastDnsMode::Disabled { + log::warn!( + "remote mDNS candidate added, but mDNS is disabled: ({})", + c.address() + ); + return Ok(()); + } + + if c.candidate_type() != CandidateType::Host { + return Err(Error::ErrAddressParseFailed); + } + + let ai = Arc::clone(&self.internal); + let host_candidate = Arc::clone(c); + let mdns_conn = self.mdns_conn.clone(); + tokio::spawn(async move { + if let Some(mdns_conn) = mdns_conn { + if let Ok(candidate) = + Self::resolve_and_add_multicast_candidate(mdns_conn, host_candidate).await + { + ai.add_remote_candidate(&candidate).await; + } + } + }); + } else { + let ai = Arc::clone(&self.internal); + let candidate = Arc::clone(c); + tokio::spawn(async move { + ai.add_remote_candidate(&candidate).await; + }); + } + + Ok(()) + } + + /// Returns the local candidates. + pub async fn get_local_candidates(&self) -> Result>> { + let mut res = vec![]; + + { + let local_candidates = self.internal.local_candidates.lock().await; + for candidates in local_candidates.values() { + for candidate in candidates { + res.push(Arc::clone(candidate)); + } + } + } + + Ok(res) + } + + /// Returns the local user credentials. + pub async fn get_local_user_credentials(&self) -> (String, String) { + let ufrag_pwd = self.internal.ufrag_pwd.lock().await; + (ufrag_pwd.local_ufrag.clone(), ufrag_pwd.local_pwd.clone()) + } + + /// Returns the remote user credentials. + pub async fn get_remote_user_credentials(&self) -> (String, String) { + let ufrag_pwd = self.internal.ufrag_pwd.lock().await; + (ufrag_pwd.remote_ufrag.clone(), ufrag_pwd.remote_pwd.clone()) + } + + /// Cleans up the Agent. + pub async fn close(&self) -> Result<()> { + if let Some(gather_candidate_cancel) = &self.gather_candidate_cancel { + gather_candidate_cancel(); + } + + if let UDPNetwork::Muxed(ref udp_mux) = self.udp_network { + let (ufrag, _) = self.get_local_user_credentials().await; + udp_mux.remove_conn_by_ufrag(&ufrag).await; + } + + //FIXME: deadlock here + self.internal.close().await + } + + /// Returns the selected pair or nil if there is none + pub fn get_selected_candidate_pair(&self) -> Option> { + self.internal.agent_conn.get_selected_pair() + } + + /// Sets the credentials of the remote agent. + pub async fn set_remote_credentials( + &self, + remote_ufrag: String, + remote_pwd: String, + ) -> Result<()> { + self.internal + .set_remote_credentials(remote_ufrag, remote_pwd) + .await + } + + /// Restarts the ICE Agent with the provided ufrag/pwd + /// If no ufrag/pwd is provided the Agent will generate one itself. + /// + /// Restart must only be called when `GatheringState` is `GatheringStateComplete` + /// a user must then call `GatherCandidates` explicitly to start generating new ones. 
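+    ///
+    /// A rough restart flow sketched from the `test_agent_restart_both_side`
+    /// test (both sides restart, re-exchange credentials, then re-gather; error
+    /// handling and the reverse credential exchange are elided):
+    ///
+    /// ```ignore
+    /// // Empty strings ask the agent to generate fresh credentials.
+    /// agent_a.restart("".to_owned(), "".to_owned()).await?;
+    /// agent_b.restart("".to_owned(), "".to_owned()).await?;
+    ///
+    /// // Re-signal the new credentials...
+    /// let (ufrag, pwd) = agent_b.get_local_user_credentials().await;
+    /// agent_a.set_remote_credentials(ufrag, pwd).await?;
+    ///
+    /// // ...and gather/exchange new candidates (an `on_candidate` handler
+    /// // must be registered before gathering).
+    /// agent_a.gather_candidates()?;
+    /// ```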
+ pub async fn restart(&self, mut ufrag: String, mut pwd: String) -> Result<()> { + if ufrag.is_empty() { + ufrag = generate_ufrag(); + } + if pwd.is_empty() { + pwd = generate_pwd(); + } + + if ufrag.len() * 8 < 24 { + return Err(Error::ErrLocalUfragInsufficientBits); + } + if pwd.len() * 8 < 128 { + return Err(Error::ErrLocalPwdInsufficientBits); + } + + if GatheringState::from(self.gathering_state.load(Ordering::SeqCst)) + == GatheringState::Gathering + { + return Err(Error::ErrRestartWhenGathering); + } + self.gathering_state + .store(GatheringState::New as u8, Ordering::SeqCst); + + { + let done_tx = self.internal.done_tx.lock().await; + if done_tx.is_none() { + return Err(Error::ErrClosed); + } + } + + // Clear all agent needed to take back to fresh state + { + let mut ufrag_pwd = self.internal.ufrag_pwd.lock().await; + ufrag_pwd.local_ufrag = ufrag; + ufrag_pwd.local_pwd = pwd; + ufrag_pwd.remote_ufrag = String::new(); + ufrag_pwd.remote_pwd = String::new(); + } + { + let mut pending_binding_requests = self.internal.pending_binding_requests.lock().await; + *pending_binding_requests = vec![]; + } + + { + let mut checklist = self.internal.agent_conn.checklist.lock().await; + *checklist = vec![]; + } + + self.internal.set_selected_pair(None).await; + self.internal.delete_all_candidates().await; + self.internal.start().await; + + // Restart is used by NewAgent. Accept/Connect should be used to move to checking + // for new Agents + if self.internal.connection_state.load(Ordering::SeqCst) != ConnectionState::New as u8 { + self.internal + .update_connection_state(ConnectionState::Checking) + .await; + } + + Ok(()) + } + + /// Initiates the trickle based gathering process. + pub fn gather_candidates(&self) -> Result<()> { + if self.gathering_state.load(Ordering::SeqCst) != GatheringState::New as u8 { + return Err(Error::ErrMultipleGatherAttempted); + } + + if self.internal.on_candidate_hdlr.load().is_none() { + return Err(Error::ErrNoOnCandidateHandler); + } + + if let Some(gather_candidate_cancel) = &self.gather_candidate_cancel { + gather_candidate_cancel(); // Cancel previous gathering routine + } + + //TODO: a.gatherCandidateCancel = cancel + + let params = GatherCandidatesInternalParams { + udp_network: self.udp_network.clone(), + candidate_types: self.candidate_types.clone(), + urls: self.urls.clone(), + network_types: self.network_types.clone(), + mdns_mode: self.mdns_mode, + mdns_name: self.mdns_name.clone(), + net: Arc::clone(&self.net), + interface_filter: self.interface_filter.clone(), + ip_filter: self.ip_filter.clone(), + ext_ip_mapper: Arc::clone(&self.ext_ip_mapper), + agent_internal: Arc::clone(&self.internal), + gathering_state: Arc::clone(&self.gathering_state), + chan_candidate_tx: Arc::clone(&self.internal.chan_candidate_tx), + }; + tokio::spawn(async move { + Self::gather_candidates_internal(params).await; + }); + + Ok(()) + } + + /// Returns a list of candidate pair stats. + pub async fn get_candidate_pairs_stats(&self) -> Vec { + self.internal.get_candidate_pairs_stats().await + } + + /// Returns a list of local candidates stats. + pub async fn get_local_candidates_stats(&self) -> Vec { + self.internal.get_local_candidates_stats().await + } + + /// Returns a list of remote candidates stats. 
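+    ///
+    /// Like the local variant above, this is a point-in-time snapshot, e.g.:
+    ///
+    /// ```ignore
+    /// let remote_stats = agent.get_remote_candidates_stats().await;
+    /// // one entry per remote candidate known to the agent
+    /// ```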
+ pub async fn get_remote_candidates_stats(&self) -> Vec { + self.internal.get_remote_candidates_stats().await + } + + async fn resolve_and_add_multicast_candidate( + mdns_conn: Arc, + c: Arc, + ) -> Result> { + //TODO: hook up _close_query_signal_tx to Agent or Candidate's Close signal? + let (_close_query_signal_tx, close_query_signal_rx) = mpsc::channel(1); + let src = match mdns_conn.query(&c.address(), close_query_signal_rx).await { + Ok((_, src)) => src, + Err(err) => { + log::warn!("Failed to discover mDNS candidate {}: {}", c.address(), err); + return Err(err.into()); + } + }; + + c.set_ip(&src.ip())?; + + Ok(c) + } + + async fn close_multicast_conn(mdns_conn: &Option>) { + if let Some(conn) = mdns_conn { + if let Err(err) = conn.close().await { + log::warn!("failed to close mDNS Conn: {}", err); + } + } + } +} diff --git a/reserved/ice/src/candidate/candidate_base.rs b/reserved/ice/src/candidate/candidate_base.rs new file mode 100644 index 0000000..30cb3be --- /dev/null +++ b/reserved/ice/src/candidate/candidate_base.rs @@ -0,0 +1,524 @@ +use std::fmt; +use std::ops::Add; +use std::sync::atomic::{AtomicU16, AtomicU64, AtomicU8, Ordering}; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use async_trait::async_trait; +use crc::{Crc, CRC_32_ISCSI}; +use tokio::sync::{broadcast, Mutex}; +use util::sync::Mutex as SyncMutex; + +use super::*; +use crate::candidate::candidate_host::CandidateHostConfig; +use crate::candidate::candidate_peer_reflexive::CandidatePeerReflexiveConfig; +use crate::candidate::candidate_relay::CandidateRelayConfig; +use crate::candidate::candidate_server_reflexive::CandidateServerReflexiveConfig; +use crate::error::*; +use crate::util::*; + +#[derive(Default)] +pub struct CandidateBaseConfig { + pub candidate_id: String, + pub network: String, + pub address: String, + pub port: u16, + pub component: u16, + pub priority: u32, + pub foundation: String, + pub conn: Option>, + pub initialized_ch: Option>, +} + +pub struct CandidateBase { + pub(crate) id: String, + pub(crate) network_type: AtomicU8, + pub(crate) candidate_type: CandidateType, + + pub(crate) component: AtomicU16, + pub(crate) address: String, + pub(crate) port: u16, + pub(crate) related_address: Option, + pub(crate) tcp_type: TcpType, + + pub(crate) resolved_addr: SyncMutex, + + pub(crate) last_sent: AtomicU64, + pub(crate) last_received: AtomicU64, + + pub(crate) conn: Option>, + pub(crate) closed_ch: Arc>>>, + + pub(crate) foundation_override: String, + pub(crate) priority_override: u32, + + //CandidateHost + pub(crate) network: String, + //CandidateRelay + pub(crate) relay_client: Option>, +} + +impl Default for CandidateBase { + fn default() -> Self { + Self { + id: String::new(), + network_type: AtomicU8::new(0), + candidate_type: CandidateType::default(), + + component: AtomicU16::new(0), + address: String::new(), + port: 0, + related_address: None, + tcp_type: TcpType::default(), + + resolved_addr: SyncMutex::new(SocketAddr::new(IpAddr::from([0, 0, 0, 0]), 0)), + + last_sent: AtomicU64::new(0), + last_received: AtomicU64::new(0), + + conn: None, + closed_ch: Arc::new(Mutex::new(None)), + + foundation_override: String::new(), + priority_override: 0, + network: String::new(), + relay_client: None, + } + } +} + +// String makes the candidateBase printable +impl fmt::Display for CandidateBase { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(related_address) = self.related_address() { + write!( + f, + "{} {} {}:{}{}", + self.network_type(), + 
self.candidate_type(), + self.address(), + self.port(), + related_address, + ) + } else { + write!( + f, + "{} {} {}:{}", + self.network_type(), + self.candidate_type(), + self.address(), + self.port(), + ) + } + } +} + +#[async_trait] +impl Candidate for CandidateBase { + fn foundation(&self) -> String { + if !self.foundation_override.is_empty() { + return self.foundation_override.clone(); + } + + let mut buf = vec![]; + buf.extend_from_slice(self.candidate_type().to_string().as_bytes()); + buf.extend_from_slice(self.address.as_bytes()); + buf.extend_from_slice(self.network_type().to_string().as_bytes()); + + let checksum = Crc::::new(&CRC_32_ISCSI).checksum(&buf); + + format!("{checksum}") + } + + /// Returns Candidate ID. + fn id(&self) -> String { + self.id.clone() + } + + /// Returns candidate component. + fn component(&self) -> u16 { + self.component.load(Ordering::SeqCst) + } + + fn set_component(&self, component: u16) { + self.component.store(component, Ordering::SeqCst); + } + + /// Returns a time indicating the last time this candidate was received. + fn last_received(&self) -> SystemTime { + UNIX_EPOCH.add(Duration::from_nanos( + self.last_received.load(Ordering::SeqCst), + )) + } + + /// Returns a time indicating the last time this candidate was sent. + fn last_sent(&self) -> SystemTime { + UNIX_EPOCH.add(Duration::from_nanos(self.last_sent.load(Ordering::SeqCst))) + } + + /// Returns candidate NetworkType. + fn network_type(&self) -> NetworkType { + NetworkType::from(self.network_type.load(Ordering::SeqCst)) + } + + /// Returns Candidate Address. + fn address(&self) -> String { + self.address.clone() + } + + /// Returns Candidate Port. + fn port(&self) -> u16 { + self.port + } + + /// Computes the priority for this ICE Candidate. + fn priority(&self) -> u32 { + if self.priority_override != 0 { + return self.priority_override; + } + + // The local preference MUST be an integer from 0 (lowest preference) to + // 65535 (highest preference) inclusive. When there is only a single IP + // address, this value SHOULD be set to 65535. If there are multiple + // candidates for a particular component for a particular data stream + // that have the same type, the local preference MUST be unique for each + // one. + (1 << 24) * u32::from(self.candidate_type().preference()) + + (1 << 8) * u32::from(self.local_preference()) + + (256 - u32::from(self.component())) + } + + /// Returns `Option`. + fn related_address(&self) -> Option { + self.related_address.as_ref().cloned() + } + + /// Returns candidate type. + fn candidate_type(&self) -> CandidateType { + self.candidate_type + } + + fn tcp_type(&self) -> TcpType { + self.tcp_type + } + + /// Returns the string representation of the ICECandidate. + fn marshal(&self) -> String { + let mut val = format!( + "{} {} {} {} {} {} typ {}", + self.foundation(), + self.component(), + self.network_type().network_short(), + self.priority(), + self.address(), + self.port(), + self.candidate_type() + ); + + if self.tcp_type != TcpType::Unspecified { + val += format!(" tcptype {}", self.tcp_type()).as_str(); + } + + if let Some(related_address) = self.related_address() { + val += format!( + " raddr {} rport {}", + related_address.address, related_address.port, + ) + .as_str(); + } + + val + } + + fn addr(&self) -> SocketAddr { + *self.resolved_addr.lock() + } + + /// Stops the recvLoop. 
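+    /// Returns `Error::ErrClosed` if the close channel has already been taken (or
+    /// was never set); otherwise it is taken here and any underlying relay client
+    /// and socket are shut down as well.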
+ async fn close(&self) -> Result<()> { + { + let mut closed_ch = self.closed_ch.lock().await; + if closed_ch.is_none() { + return Err(Error::ErrClosed); + } + closed_ch.take(); + } + + if let Some(relay_client) = &self.relay_client { + let _ = relay_client.close().await; + } + + if let Some(conn) = &self.conn { + let _ = conn.close().await; + } + + Ok(()) + } + + fn seen(&self, outbound: bool) { + let d = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)); + + if outbound { + self.set_last_sent(d); + } else { + self.set_last_received(d); + } + } + + async fn write_to(&self, raw: &[u8], dst: &(dyn Candidate + Send + Sync)) -> Result { + let n = if let Some(conn) = &self.conn { + let addr = dst.addr(); + conn.send_to(raw, addr).await? + } else { + 0 + }; + self.seen(true); + Ok(n) + } + + /// Used to compare two candidateBases. + fn equal(&self, other: &dyn Candidate) -> bool { + self.network_type() == other.network_type() + && self.candidate_type() == other.candidate_type() + && self.address() == other.address() + && self.port() == other.port() + && self.tcp_type() == other.tcp_type() + && self.related_address() == other.related_address() + } + + fn set_ip(&self, ip: &IpAddr) -> Result<()> { + let network_type = determine_network_type(&self.network, ip)?; + + self.network_type + .store(network_type as u8, Ordering::SeqCst); + + let addr = create_addr(network_type, *ip, self.port); + *self.resolved_addr.lock() = addr; + + Ok(()) + } + + fn get_conn(&self) -> Option<&Arc> { + self.conn.as_ref() + } + + fn get_closed_ch(&self) -> Arc>>> { + self.closed_ch.clone() + } +} + +impl CandidateBase { + pub fn set_last_received(&self, d: Duration) { + #[allow(clippy::cast_possible_truncation)] + self.last_received + .store(d.as_nanos() as u64, Ordering::SeqCst); + } + + pub fn set_last_sent(&self, d: Duration) { + #[allow(clippy::cast_possible_truncation)] + self.last_sent.store(d.as_nanos() as u64, Ordering::SeqCst); + } + + /// Returns the local preference for this candidate. + pub fn local_preference(&self) -> u16 { + if self.network_type().is_tcp() { + // RFC 6544, section 4.2 + // + // In Section 4.1.2.1 of [RFC5245], a recommended formula for UDP ICE + // candidate prioritization is defined. For TCP candidates, the same + // formula and candidate type preferences SHOULD be used, and the + // RECOMMENDED type preferences for the new candidate types defined in + // this document (see Section 5) are 105 for NAT-assisted candidates and + // 75 for UDP-tunneled candidates. + // + // (...) + // + // With TCP candidates, the local preference part of the recommended + // priority formula is updated to also include the directionality + // (active, passive, or simultaneous-open) of the TCP connection. The + // RECOMMENDED local preference is then defined as: + // + // local preference = (2^13) * direction-pref + other-pref + // + // The direction-pref MUST be between 0 and 7 (both inclusive), with 7 + // being the most preferred. The other-pref MUST be between 0 and 8191 + // (both inclusive), with 8191 being the most preferred. It is + // RECOMMENDED that the host, UDP-tunneled, and relayed TCP candidates + // have the direction-pref assigned as follows: 6 for active, 4 for + // passive, and 2 for S-O. For the NAT-assisted and server reflexive + // candidates, the RECOMMENDED values are: 6 for S-O, 4 for active, and + // 2 for passive. + // + // (...) 
+ // + // If any two candidates have the same type-preference and direction- + // pref, they MUST have a unique other-pref. With this specification, + // this usually only happens with multi-homed hosts, in which case + // other-pref is the preference for the particular IP address from which + // the candidate was obtained. When there is only a single IP address, + // this value SHOULD be set to the maximum allowed value (8191). + let other_pref: u16 = 8191; + + let direction_pref: u16 = match self.candidate_type() { + CandidateType::Host | CandidateType::Relay => match self.tcp_type() { + TcpType::Active => 6, + TcpType::Passive => 4, + TcpType::SimultaneousOpen => 2, + TcpType::Unspecified => 0, + }, + CandidateType::PeerReflexive | CandidateType::ServerReflexive => { + match self.tcp_type() { + TcpType::SimultaneousOpen => 6, + TcpType::Active => 4, + TcpType::Passive => 2, + TcpType::Unspecified => 0, + } + } + CandidateType::Unspecified => 0, + }; + + (1 << 13) * direction_pref + other_pref + } else { + DEFAULT_LOCAL_PREFERENCE + } + } +} + +/// Creates a Candidate from its string representation. +pub fn unmarshal_candidate(raw: &str) -> Result { + let split: Vec<&str> = raw.split_whitespace().collect(); + if split.len() < 8 { + return Err(Error::Other(format!( + "{:?} ({})", + Error::ErrAttributeTooShortIceCandidate, + split.len() + ))); + } + + // Foundation + let foundation = split[0].to_owned(); + + // Component + let component: u16 = split[1].parse()?; + + // Network + let network = split[2].to_owned(); + + // Priority + let priority: u32 = split[3].parse()?; + + // Address + let address = split[4].to_owned(); + + // Port + let port: u16 = split[5].parse()?; + + let typ = split[7]; + + let mut rel_addr = String::new(); + let mut rel_port = 0; + let mut tcp_type = TcpType::Unspecified; + + if split.len() > 8 { + let split2 = &split[8..]; + + if split2[0] == "raddr" { + if split2.len() < 4 { + return Err(Error::Other(format!( + "{:?}: incorrect length", + Error::ErrParseRelatedAddr + ))); + } + + // RelatedAddress + rel_addr = split2[1].to_owned(); + + // RelatedPort + rel_port = split2[3].parse()?; + } else if split2[0] == "tcptype" { + if split2.len() < 2 { + return Err(Error::Other(format!( + "{:?}: incorrect length", + Error::ErrParseType + ))); + } + + tcp_type = TcpType::from(split2[1]); + } + } + + match typ { + "host" => { + let config = CandidateHostConfig { + base_config: CandidateBaseConfig { + network, + address, + port, + component, + priority, + foundation, + ..CandidateBaseConfig::default() + }, + tcp_type, + }; + config.new_candidate_host() + } + "srflx" => { + let config = CandidateServerReflexiveConfig { + base_config: CandidateBaseConfig { + network, + address, + port, + component, + priority, + foundation, + ..CandidateBaseConfig::default() + }, + rel_addr, + rel_port, + }; + config.new_candidate_server_reflexive() + } + "prflx" => { + let config = CandidatePeerReflexiveConfig { + base_config: CandidateBaseConfig { + network, + address, + port, + component, + priority, + foundation, + ..CandidateBaseConfig::default() + }, + rel_addr, + rel_port, + }; + + config.new_candidate_peer_reflexive() + } + "relay" => { + let config = CandidateRelayConfig { + base_config: CandidateBaseConfig { + network, + address, + port, + component, + priority, + foundation, + ..CandidateBaseConfig::default() + }, + rel_addr, + rel_port, + ..CandidateRelayConfig::default() + }; + config.new_candidate_relay() + } + _ => Err(Error::Other(format!( + "{:?} ({})", + 
Error::ErrUnknownCandidateType, + typ + ))), + } +} diff --git a/reserved/ice/src/candidate/candidate_host.rs b/reserved/ice/src/candidate/candidate_host.rs new file mode 100644 index 0000000..a904d9f --- /dev/null +++ b/reserved/ice/src/candidate/candidate_host.rs @@ -0,0 +1,45 @@ +use std::sync::atomic::{AtomicU16, AtomicU8}; + +use super::candidate_base::*; +use super::*; +use crate::rand::generate_cand_id; + +/// The config required to create a new `CandidateHost`. +#[derive(Default)] +pub struct CandidateHostConfig { + pub base_config: CandidateBaseConfig, + + pub tcp_type: TcpType, +} + +impl CandidateHostConfig { + /// Creates a new host candidate. + pub fn new_candidate_host(self) -> Result { + let mut candidate_id = self.base_config.candidate_id; + if candidate_id.is_empty() { + candidate_id = generate_cand_id(); + } + + let c = CandidateBase { + id: candidate_id, + address: self.base_config.address.clone(), + candidate_type: CandidateType::Host, + component: AtomicU16::new(self.base_config.component), + port: self.base_config.port, + tcp_type: self.tcp_type, + foundation_override: self.base_config.foundation, + priority_override: self.base_config.priority, + network: self.base_config.network, + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + conn: self.base_config.conn, + ..CandidateBase::default() + }; + + if !self.base_config.address.ends_with(".local") { + let ip = self.base_config.address.parse()?; + c.set_ip(&ip)?; + }; + + Ok(c) + } +} diff --git a/reserved/ice/src/candidate/candidate_pair_test.rs b/reserved/ice/src/candidate/candidate_pair_test.rs new file mode 100644 index 0000000..7b2765a --- /dev/null +++ b/reserved/ice/src/candidate/candidate_pair_test.rs @@ -0,0 +1,155 @@ +use super::*; +use crate::candidate::candidate_host::CandidateHostConfig; +use crate::candidate::candidate_peer_reflexive::CandidatePeerReflexiveConfig; +use crate::candidate::candidate_relay::CandidateRelayConfig; +use crate::candidate::candidate_server_reflexive::CandidateServerReflexiveConfig; + +pub(crate) fn host_candidate() -> Result { + CandidateHostConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "0.0.0.0".to_owned(), + component: COMPONENT_RTP, + ..Default::default() + }, + ..Default::default() + } + .new_candidate_host() +} + +pub(crate) fn prflx_candidate() -> Result { + CandidatePeerReflexiveConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "0.0.0.0".to_owned(), + component: COMPONENT_RTP, + ..Default::default() + }, + ..Default::default() + } + .new_candidate_peer_reflexive() +} + +pub(crate) fn srflx_candidate() -> Result { + CandidateServerReflexiveConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "0.0.0.0".to_owned(), + component: COMPONENT_RTP, + ..Default::default() + }, + ..Default::default() + } + .new_candidate_server_reflexive() +} + +pub(crate) fn relay_candidate() -> Result { + CandidateRelayConfig { + base_config: CandidateBaseConfig { + network: "udp".to_owned(), + address: "0.0.0.0".to_owned(), + component: COMPONENT_RTP, + ..Default::default() + }, + ..Default::default() + } + .new_candidate_relay() +} + +#[test] +fn test_candidate_pair_priority() -> Result<()> { + let tests = vec![ + ( + CandidatePair::new( + Arc::new(host_candidate()?), + Arc::new(host_candidate()?), + false, + ), + 9151314440652587007, + ), + ( + CandidatePair::new( + Arc::new(host_candidate()?), + Arc::new(host_candidate()?), + true, + ), + 9151314440652587007, + ), + ( + 
CandidatePair::new( + Arc::new(host_candidate()?), + Arc::new(prflx_candidate()?), + true, + ), + 7998392936314175488, + ), + ( + CandidatePair::new( + Arc::new(host_candidate()?), + Arc::new(prflx_candidate()?), + false, + ), + 7998392936314175487, + ), + ( + CandidatePair::new( + Arc::new(host_candidate()?), + Arc::new(srflx_candidate()?), + true, + ), + 7277816996102668288, + ), + ( + CandidatePair::new( + Arc::new(host_candidate()?), + Arc::new(srflx_candidate()?), + false, + ), + 7277816996102668287, + ), + ( + CandidatePair::new( + Arc::new(host_candidate()?), + Arc::new(relay_candidate()?), + true, + ), + 72057593987596288, + ), + ( + CandidatePair::new( + Arc::new(host_candidate()?), + Arc::new(relay_candidate()?), + false, + ), + 72057593987596287, + ), + ]; + + for (pair, want) in tests { + let got = pair.priority(); + assert_eq!( + got, want, + "CandidatePair({pair}).Priority() = {got}, want {want}" + ); + } + + Ok(()) +} + +#[test] +fn test_candidate_pair_equality() -> Result<()> { + let pair_a = CandidatePair::new( + Arc::new(host_candidate()?), + Arc::new(srflx_candidate()?), + true, + ); + let pair_b = CandidatePair::new( + Arc::new(host_candidate()?), + Arc::new(srflx_candidate()?), + false, + ); + + assert_eq!(pair_a, pair_b, "Expected {pair_a} to equal {pair_b}"); + + Ok(()) +} diff --git a/reserved/ice/src/candidate/candidate_peer_reflexive.rs b/reserved/ice/src/candidate/candidate_peer_reflexive.rs new file mode 100644 index 0000000..b60b9ea --- /dev/null +++ b/reserved/ice/src/candidate/candidate_peer_reflexive.rs @@ -0,0 +1,54 @@ +use std::sync::atomic::{AtomicU16, AtomicU8}; + +use util::sync::Mutex as SyncMutex; + +use super::candidate_base::*; +use super::*; +use crate::error::*; +use crate::rand::generate_cand_id; +use crate::util::*; + +/// The config required to create a new `CandidatePeerReflexive`. +#[derive(Default)] +pub struct CandidatePeerReflexiveConfig { + pub base_config: CandidateBaseConfig, + + pub rel_addr: String, + pub rel_port: u16, +} + +impl CandidatePeerReflexiveConfig { + /// Creates a new peer reflective candidate. 
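+    ///
+    /// A construction sketch, mirroring the test helpers in `candidate_pair_test.rs`
+    /// (the address and network values are placeholders):
+    ///
+    /// ```ignore
+    /// let prflx = CandidatePeerReflexiveConfig {
+    ///     base_config: CandidateBaseConfig {
+    ///         network: "udp".to_owned(),
+    ///         address: "0.0.0.0".to_owned(),
+    ///         component: COMPONENT_RTP,
+    ///         ..Default::default()
+    ///     },
+    ///     ..Default::default()
+    /// }
+    /// .new_candidate_peer_reflexive()?;
+    /// ```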
+ pub fn new_candidate_peer_reflexive(self) -> Result { + let ip: IpAddr = match self.base_config.address.parse() { + Ok(ip) => ip, + Err(_) => return Err(Error::ErrAddressParseFailed), + }; + let network_type = determine_network_type(&self.base_config.network, &ip)?; + + let mut candidate_id = self.base_config.candidate_id; + if candidate_id.is_empty() { + candidate_id = generate_cand_id(); + } + + let c = CandidateBase { + id: candidate_id, + network_type: AtomicU8::new(network_type as u8), + candidate_type: CandidateType::PeerReflexive, + address: self.base_config.address, + port: self.base_config.port, + resolved_addr: SyncMutex::new(create_addr(network_type, ip, self.base_config.port)), + component: AtomicU16::new(self.base_config.component), + foundation_override: self.base_config.foundation, + priority_override: self.base_config.priority, + related_address: Some(CandidateRelatedAddress { + address: self.rel_addr, + port: self.rel_port, + }), + conn: self.base_config.conn, + ..CandidateBase::default() + }; + + Ok(c) + } +} diff --git a/reserved/ice/src/candidate/candidate_relay.rs b/reserved/ice/src/candidate/candidate_relay.rs new file mode 100644 index 0000000..50dd3e7 --- /dev/null +++ b/reserved/ice/src/candidate/candidate_relay.rs @@ -0,0 +1,57 @@ +use std::sync::atomic::{AtomicU16, AtomicU8}; +use std::sync::Arc; + +use util::sync::Mutex as SyncMutex; + +use super::candidate_base::*; +use super::*; +use crate::error::*; +use crate::rand::generate_cand_id; +use crate::util::*; + +/// The config required to create a new `CandidateRelay`. +#[derive(Default)] +pub struct CandidateRelayConfig { + pub base_config: CandidateBaseConfig, + + pub rel_addr: String, + pub rel_port: u16, + pub relay_client: Option>, +} + +impl CandidateRelayConfig { + /// Creates a new relay candidate. 
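+    ///
+    /// `rel_addr`/`rel_port` become the candidate's related address (the reflexive
+    /// address the relay allocation maps back to), and `relay_client`, when set,
+    /// stays alive with the candidate and is closed in `CandidateBase::close`.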
+ pub fn new_candidate_relay(self) -> Result { + let mut candidate_id = self.base_config.candidate_id; + if candidate_id.is_empty() { + candidate_id = generate_cand_id(); + } + + let ip: IpAddr = match self.base_config.address.parse() { + Ok(ip) => ip, + Err(_) => return Err(Error::ErrAddressParseFailed), + }; + let network_type = determine_network_type(&self.base_config.network, &ip)?; + + let c = CandidateBase { + id: candidate_id, + network_type: AtomicU8::new(network_type as u8), + candidate_type: CandidateType::Relay, + address: self.base_config.address, + port: self.base_config.port, + resolved_addr: SyncMutex::new(create_addr(network_type, ip, self.base_config.port)), + component: AtomicU16::new(self.base_config.component), + foundation_override: self.base_config.foundation, + priority_override: self.base_config.priority, + related_address: Some(CandidateRelatedAddress { + address: self.rel_addr, + port: self.rel_port, + }), + conn: self.base_config.conn, + relay_client: self.relay_client.clone(), + ..CandidateBase::default() + }; + + Ok(c) + } +} diff --git a/reserved/ice/src/candidate/candidate_relay_test.rs b/reserved/ice/src/candidate/candidate_relay_test.rs new file mode 100644 index 0000000..4d6b5bb --- /dev/null +++ b/reserved/ice/src/candidate/candidate_relay_test.rs @@ -0,0 +1,114 @@ +use std::result::Result; +use std::time::Duration; + +use tokio::net::UdpSocket; +use turn::auth::AuthHandler; + +use super::*; +use crate::agent::agent_config::AgentConfig; +use crate::agent::agent_vnet_test::{connect_with_vnet, on_connected}; +use crate::agent::Agent; +use crate::error::Error; +use crate::url::{ProtoType, SchemeType, Url}; + +pub(crate) struct OptimisticAuthHandler; + +impl AuthHandler for OptimisticAuthHandler { + fn auth_handle( + &self, + _username: &str, + _realm: &str, + _src_addr: SocketAddr, + ) -> Result, turn::Error> { + Ok(turn::auth::generate_auth_key( + "username", + "webrtc.rs", + "password", + )) + } +} + +//use std::io::Write; + +#[tokio::test] +async fn test_relay_only_connection() -> Result<(), Error> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init();*/ + + let server_listener = Arc::new(UdpSocket::bind("127.0.0.1:0").await?); + let server_port = server_listener.local_addr()?.port(); + + let server = turn::server::Server::new(turn::server::config::ServerConfig { + realm: "webrtc.rs".to_owned(), + auth_handler: Arc::new(OptimisticAuthHandler {}), + conn_configs: vec![turn::server::config::ConnConfig { + conn: server_listener, + relay_addr_generator: Box::new(turn::relay::relay_none::RelayAddressGeneratorNone { + address: "127.0.0.1".to_owned(), + net: Arc::new(util::vnet::net::Net::new(None)), + }), + }], + channel_bind_timeout: Duration::from_secs(0), + //alloc_close_notify: None, + }) + .await?; + + let cfg0 = AgentConfig { + network_types: supported_network_types(), + urls: vec![Url { + scheme: SchemeType::Turn, + host: "127.0.0.1".to_owned(), + username: "username".to_owned(), + password: "password".to_owned(), + port: server_port, + proto: ProtoType::Udp, + }], + candidate_types: vec![CandidateType::Relay], + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + let (a_notifier, mut a_connected) = on_connected(); + a_agent.on_connection_state_change(a_notifier); + + let 
cfg1 = AgentConfig { + network_types: supported_network_types(), + urls: vec![Url { + scheme: SchemeType::Turn, + host: "127.0.0.1".to_owned(), + username: "username".to_owned(), + password: "password".to_owned(), + port: server_port, + proto: ProtoType::Udp, + }], + candidate_types: vec![CandidateType::Relay], + ..Default::default() + }; + + let b_agent = Arc::new(Agent::new(cfg1).await?); + let (b_notifier, mut b_connected) = on_connected(); + b_agent.on_connection_state_change(b_notifier); + + connect_with_vnet(&a_agent, &b_agent).await?; + + let _ = a_connected.recv().await; + let _ = b_connected.recv().await; + + a_agent.close().await?; + b_agent.close().await?; + server.close().await?; + + Ok(()) +} diff --git a/reserved/ice/src/candidate/candidate_server_reflexive.rs b/reserved/ice/src/candidate/candidate_server_reflexive.rs new file mode 100644 index 0000000..2651e35 --- /dev/null +++ b/reserved/ice/src/candidate/candidate_server_reflexive.rs @@ -0,0 +1,54 @@ +use std::sync::atomic::{AtomicU16, AtomicU8}; + +use util::sync::Mutex as SyncMutex; + +use super::candidate_base::*; +use super::*; +use crate::error::*; +use crate::rand::generate_cand_id; +use crate::util::*; + +/// The config required to create a new `CandidateServerReflexive`. +#[derive(Default)] +pub struct CandidateServerReflexiveConfig { + pub base_config: CandidateBaseConfig, + + pub rel_addr: String, + pub rel_port: u16, +} + +impl CandidateServerReflexiveConfig { + /// Creates a new server reflective candidate. + pub fn new_candidate_server_reflexive(self) -> Result { + let ip: IpAddr = match self.base_config.address.parse() { + Ok(ip) => ip, + Err(_) => return Err(Error::ErrAddressParseFailed), + }; + let network_type = determine_network_type(&self.base_config.network, &ip)?; + + let mut candidate_id = self.base_config.candidate_id; + if candidate_id.is_empty() { + candidate_id = generate_cand_id(); + } + + let c = CandidateBase { + id: candidate_id, + network_type: AtomicU8::new(network_type as u8), + candidate_type: CandidateType::ServerReflexive, + address: self.base_config.address, + port: self.base_config.port, + resolved_addr: SyncMutex::new(create_addr(network_type, ip, self.base_config.port)), + component: AtomicU16::new(self.base_config.component), + foundation_override: self.base_config.foundation, + priority_override: self.base_config.priority, + related_address: Some(CandidateRelatedAddress { + address: self.rel_addr, + port: self.rel_port, + }), + conn: self.base_config.conn, + ..CandidateBase::default() + }; + + Ok(c) + } +} diff --git a/reserved/ice/src/candidate/candidate_server_reflexive_test.rs b/reserved/ice/src/candidate/candidate_server_reflexive_test.rs new file mode 100644 index 0000000..c6690a3 --- /dev/null +++ b/reserved/ice/src/candidate/candidate_server_reflexive_test.rs @@ -0,0 +1,91 @@ +use std::time::Duration; + +use tokio::net::UdpSocket; + +use super::candidate_relay_test::OptimisticAuthHandler; +use super::*; +use crate::agent::agent_config::AgentConfig; +use crate::agent::agent_vnet_test::{connect_with_vnet, on_connected}; +use crate::agent::Agent; +use crate::url::{SchemeType, Url}; + +//use std::io::Write; + +#[tokio::test] +async fn test_server_reflexive_only_connection() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, 
log::LevelFilter::Trace) + .init();*/ + + let server_listener = Arc::new(UdpSocket::bind("127.0.0.1:0").await?); + let server_port = server_listener.local_addr()?.port(); + + let server = turn::server::Server::new(turn::server::config::ServerConfig { + realm: "webrtc.rs".to_owned(), + auth_handler: Arc::new(OptimisticAuthHandler {}), + conn_configs: vec![turn::server::config::ConnConfig { + conn: server_listener, + relay_addr_generator: Box::new(turn::relay::relay_none::RelayAddressGeneratorNone { + address: "127.0.0.1".to_owned(), + net: Arc::new(util::vnet::net::Net::new(None)), + }), + }], + channel_bind_timeout: Duration::from_secs(0), + //alloc_close_notify: None, + }) + .await?; + + let cfg0 = AgentConfig { + network_types: vec![NetworkType::Udp4], + urls: vec![Url { + scheme: SchemeType::Stun, + host: "127.0.0.1".to_owned(), + port: server_port, + ..Default::default() + }], + candidate_types: vec![CandidateType::ServerReflexive], + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + let (a_notifier, mut a_connected) = on_connected(); + a_agent.on_connection_state_change(a_notifier); + + let cfg1 = AgentConfig { + network_types: vec![NetworkType::Udp4], + urls: vec![Url { + scheme: SchemeType::Stun, + host: "127.0.0.1".to_owned(), + port: server_port, + ..Default::default() + }], + candidate_types: vec![CandidateType::ServerReflexive], + ..Default::default() + }; + + let b_agent = Arc::new(Agent::new(cfg1).await?); + let (b_notifier, mut b_connected) = on_connected(); + b_agent.on_connection_state_change(b_notifier); + + connect_with_vnet(&a_agent, &b_agent).await?; + + let _ = a_connected.recv().await; + let _ = b_connected.recv().await; + + a_agent.close().await?; + b_agent.close().await?; + server.close().await?; + + Ok(()) +} diff --git a/reserved/ice/src/candidate/candidate_test.rs b/reserved/ice/src/candidate/candidate_test.rs new file mode 100644 index 0000000..b9f2928 --- /dev/null +++ b/reserved/ice/src/candidate/candidate_test.rs @@ -0,0 +1,411 @@ +use std::time::UNIX_EPOCH; + +use super::*; + +#[test] +fn test_candidate_priority() -> Result<()> { + let tests = vec![ + ( + CandidateBase { + candidate_type: CandidateType::Host, + component: AtomicU16::new(COMPONENT_RTP), + ..Default::default() + }, + 2130706431, + ), + ( + CandidateBase { + candidate_type: CandidateType::Host, + component: AtomicU16::new(COMPONENT_RTP), + network_type: AtomicU8::new(NetworkType::Tcp4 as u8), + tcp_type: TcpType::Active, + ..Default::default() + }, + 2128609279, + ), + ( + CandidateBase { + candidate_type: CandidateType::Host, + component: AtomicU16::new(COMPONENT_RTP), + network_type: AtomicU8::new(NetworkType::Tcp4 as u8), + tcp_type: TcpType::Passive, + ..Default::default() + }, + 2124414975, + ), + ( + CandidateBase { + candidate_type: CandidateType::Host, + component: AtomicU16::new(COMPONENT_RTP), + network_type: AtomicU8::new(NetworkType::Tcp4 as u8), + tcp_type: TcpType::SimultaneousOpen, + ..Default::default() + }, + 2120220671, + ), + ( + CandidateBase { + candidate_type: CandidateType::PeerReflexive, + component: AtomicU16::new(COMPONENT_RTP), + ..Default::default() + }, + 1862270975, + ), + ( + CandidateBase { + candidate_type: CandidateType::PeerReflexive, + component: AtomicU16::new(COMPONENT_RTP), + network_type: AtomicU8::new(NetworkType::Tcp6 as u8), + tcp_type: TcpType::SimultaneousOpen, + ..Default::default() + }, + 1860173823, + ), + ( + CandidateBase { + candidate_type: CandidateType::PeerReflexive, + component: AtomicU16::new(COMPONENT_RTP), 
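+                // Worked check for this prflx/TCP-active entry (remaining fields below):
+                // local preference = (1 << 13) * 4 + 8191 = 40959, so
+                // priority = (1 << 24) * 110 + (1 << 8) * 40959 + (256 - 1) = 1855979519.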
+ network_type: AtomicU8::new(NetworkType::Tcp6 as u8), + tcp_type: TcpType::Active, + ..Default::default() + }, + 1855979519, + ), + ( + CandidateBase { + candidate_type: CandidateType::PeerReflexive, + component: AtomicU16::new(COMPONENT_RTP), + network_type: AtomicU8::new(NetworkType::Tcp6 as u8), + tcp_type: TcpType::Passive, + ..Default::default() + }, + 1851785215, + ), + ( + CandidateBase { + candidate_type: CandidateType::ServerReflexive, + component: AtomicU16::new(COMPONENT_RTP), + ..Default::default() + }, + 1694498815, + ), + ( + CandidateBase { + candidate_type: CandidateType::Relay, + component: AtomicU16::new(COMPONENT_RTP), + ..Default::default() + }, + 16777215, + ), + ]; + + for (candidate, want) in tests { + let got = candidate.priority(); + assert_eq!( + got, want, + "Candidate({candidate}).Priority() = {got}, want {want}" + ); + } + + Ok(()) +} + +#[test] +fn test_candidate_last_sent() -> Result<()> { + let candidate = CandidateBase::default(); + assert_eq!(candidate.last_sent(), UNIX_EPOCH); + + let now = SystemTime::now(); + let d = now.duration_since(UNIX_EPOCH)?; + candidate.set_last_sent(d); + assert_eq!(candidate.last_sent(), now); + + Ok(()) +} + +#[test] +fn test_candidate_last_received() -> Result<()> { + let candidate = CandidateBase::default(); + assert_eq!(candidate.last_received(), UNIX_EPOCH); + + let now = SystemTime::now(); + let d = now.duration_since(UNIX_EPOCH)?; + candidate.set_last_received(d); + assert_eq!(candidate.last_received(), now); + + Ok(()) +} + +#[test] +fn test_candidate_foundation() -> Result<()> { + // All fields are the same + assert_eq!( + (CandidateBase { + candidate_type: CandidateType::Host, + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + address: "A".to_owned(), + ..Default::default() + }) + .foundation(), + (CandidateBase { + candidate_type: CandidateType::Host, + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + address: "A".to_owned(), + ..Default::default() + }) + .foundation() + ); + + // Different Address + assert_ne!( + (CandidateBase { + candidate_type: CandidateType::Host, + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + address: "A".to_owned(), + ..Default::default() + }) + .foundation(), + (CandidateBase { + candidate_type: CandidateType::Host, + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + address: "B".to_owned(), + ..Default::default() + }) + .foundation(), + ); + + // Different networkType + assert_ne!( + (CandidateBase { + candidate_type: CandidateType::Host, + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + address: "A".to_owned(), + ..Default::default() + }) + .foundation(), + (CandidateBase { + candidate_type: CandidateType::Host, + network_type: AtomicU8::new(NetworkType::Udp6 as u8), + address: "A".to_owned(), + ..Default::default() + }) + .foundation(), + ); + + // Different candidateType + assert_ne!( + (CandidateBase { + candidate_type: CandidateType::Host, + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + address: "A".to_owned(), + ..Default::default() + }) + .foundation(), + (CandidateBase { + candidate_type: CandidateType::PeerReflexive, + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + address: "A".to_owned(), + ..Default::default() + }) + .foundation(), + ); + + // Port has no effect + assert_eq!( + (CandidateBase { + candidate_type: CandidateType::Host, + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + address: "A".to_owned(), + port: 8080, + ..Default::default() + }) + .foundation(), + (CandidateBase { + candidate_type: 
CandidateType::Host, + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + address: "A".to_owned(), + port: 80, + ..Default::default() + }) + .foundation() + ); + + Ok(()) +} + +#[test] +fn test_candidate_pair_state_serialization() { + let tests = vec![ + (CandidatePairState::Unspecified, "\"unspecified\""), + (CandidatePairState::Waiting, "\"waiting\""), + (CandidatePairState::InProgress, "\"in-progress\""), + (CandidatePairState::Failed, "\"failed\""), + (CandidatePairState::Succeeded, "\"succeeded\""), + ]; + + for (candidate_pair_state, expected_string) in tests { + assert_eq!( + expected_string.to_string(), + serde_json::to_string(&candidate_pair_state).unwrap() + ); + } +} + +#[test] +fn test_candidate_pair_state_to_string() { + let tests = vec![ + (CandidatePairState::Unspecified, "unspecified"), + (CandidatePairState::Waiting, "waiting"), + (CandidatePairState::InProgress, "in-progress"), + (CandidatePairState::Failed, "failed"), + (CandidatePairState::Succeeded, "succeeded"), + ]; + + for (candidate_pair_state, expected_string) in tests { + assert_eq!(candidate_pair_state.to_string(), expected_string); + } +} + +#[test] +fn test_candidate_type_serialization() { + let tests = vec![ + (CandidateType::Unspecified, "\"unspecified\""), + (CandidateType::Host, "\"host\""), + (CandidateType::ServerReflexive, "\"srflx\""), + (CandidateType::PeerReflexive, "\"prflx\""), + (CandidateType::Relay, "\"relay\""), + ]; + + for (candidate_type, expected_string) in tests { + assert_eq!( + serde_json::to_string(&candidate_type).unwrap(), + expected_string.to_string() + ); + } +} + +#[test] +fn test_candidate_type_to_string() { + let tests = vec![ + (CandidateType::Unspecified, "Unknown candidate type"), + (CandidateType::Host, "host"), + (CandidateType::ServerReflexive, "srflx"), + (CandidateType::PeerReflexive, "prflx"), + (CandidateType::Relay, "relay"), + ]; + + for (candidate_type, expected_string) in tests { + assert_eq!(candidate_type.to_string(), expected_string); + } +} + +#[test] +fn test_candidate_marshal() -> Result<()> { + let tests = vec![ + ( + Some(CandidateBase{ + network_type: AtomicU8::new(NetworkType::Udp6 as u8), + candidate_type: CandidateType::Host, + address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a".to_owned(), + port: 53987, + priority_override: 500, + foundation_override: "750".to_owned(), + ..Default::default() + }), + "750 1 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a 53987 typ host", + ), + ( + Some(CandidateBase{ + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + candidate_type: CandidateType::Host, + address: "10.0.75.1".to_owned(), + port: 53634, + ..Default::default() + }), + "4273957277 1 udp 2130706431 10.0.75.1 53634 typ host", + ), + ( + Some(CandidateBase{ + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + candidate_type: CandidateType::ServerReflexive, + address: "191.228.238.68".to_owned(), + port: 53991, + related_address: Some(CandidateRelatedAddress{ + address: "192.168.0.274".to_owned(), + port:53991 + }), + ..Default::default() + }), + "647372371 1 udp 1694498815 191.228.238.68 53991 typ srflx raddr 192.168.0.274 rport 53991", + ), + ( + Some(CandidateBase{ + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + candidate_type: CandidateType::Relay, + address: "50.0.0.1".to_owned(), + port: 5000, + related_address: Some( + CandidateRelatedAddress{ + address: "192.168.0.1".to_owned(), + port:5001} + ), + ..Default::default() + }), + "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 5001", + ), + ( + 
Some(CandidateBase{ + network_type: AtomicU8::new(NetworkType::Tcp4 as u8), + candidate_type: CandidateType::Host, + address: "192.168.0.196".to_owned(), + port: 0, + tcp_type: TcpType::Active, + ..Default::default() + }), + "1052353102 1 tcp 2128609279 192.168.0.196 0 typ host tcptype active", + ), + ( + Some(CandidateBase{ + network_type: AtomicU8::new(NetworkType::Udp4 as u8), + candidate_type: CandidateType::Host, + address: "e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local".to_owned(), + port: 60542, + ..Default::default() + }), + "1380287402 1 udp 2130706431 e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local 60542 typ host", + ), + // Invalid candidates + (None, ""), + (None, "1938809241"), + (None, "1986380506 99999999 udp 2122063615 10.0.75.1 53634 typ host generation 0 network-id 2"), + (None, "1986380506 1 udp 99999999999 10.0.75.1 53634 typ host"), + (None, "4207374051 1 udp 1685790463 191.228.238.68 99999999 typ srflx raddr 192.168.0.278 rport 53991 generation 0 network-id 3"), + (None, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr"), + (None, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr 192.168.0.278 rport 99999999 generation 0 network-id 3"), + (None, "4207374051 INVALID udp 2130706431 10.0.75.1 53634 typ host"), + (None, "4207374051 1 udp INVALID 10.0.75.1 53634 typ host"), + (None, "4207374051 INVALID udp 2130706431 10.0.75.1 INVALID typ host"), + (None, "4207374051 1 udp 2130706431 10.0.75.1 53634 typ INVALID"), + ]; + + for (candidate, marshaled) in tests { + let actual_candidate = unmarshal_candidate(marshaled); + if let Some(candidate) = candidate { + if let Ok(actual_candidate) = actual_candidate { + assert!( + candidate.equal(&actual_candidate), + "{} vs {}", + candidate.marshal(), + marshaled + ); + assert_eq!(marshaled, actual_candidate.marshal()); + } else { + panic!("expected ok"); + } + } else { + assert!(actual_candidate.is_err(), "expected error"); + } + } + + Ok(()) +} diff --git a/reserved/ice/src/candidate/mod.rs b/reserved/ice/src/candidate/mod.rs new file mode 100644 index 0000000..8c43c35 --- /dev/null +++ b/reserved/ice/src/candidate/mod.rs @@ -0,0 +1,324 @@ +#[cfg(test)] +mod candidate_pair_test; +#[cfg(test)] +mod candidate_relay_test; +#[cfg(test)] +mod candidate_server_reflexive_test; +#[cfg(test)] +mod candidate_test; + +pub mod candidate_base; +pub mod candidate_host; +pub mod candidate_peer_reflexive; +pub mod candidate_relay; +pub mod candidate_server_reflexive; + +use std::fmt; +use std::net::{IpAddr, SocketAddr}; +use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU8, Ordering}; +use std::sync::Arc; +use std::time::SystemTime; + +use async_trait::async_trait; +use candidate_base::*; +use serde::Serialize; +use tokio::sync::{broadcast, Mutex}; + +use crate::error::Result; +use crate::network_type::*; +use crate::tcp_type::*; + +pub(crate) const RECEIVE_MTU: usize = 8192; +pub(crate) const DEFAULT_LOCAL_PREFERENCE: u16 = 65535; + +/// Indicates that the candidate is used for RTP. +pub(crate) const COMPONENT_RTP: u16 = 1; +/// Indicates that the candidate is used for RTCP. +pub(crate) const COMPONENT_RTCP: u16 = 0; + +/// Candidate represents an ICE candidate +#[async_trait] +pub trait Candidate: fmt::Display { + /// An arbitrary string used in the freezing algorithm to + /// group similar candidates. It is the same for two candidates that + /// have the same type, base IP address, protocol (UDP, TCP, etc.), + /// and STUN or TURN server. 
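+    ///
+    /// `CandidateBase` implements this as a CRC-32 (ISCSI) checksum over the
+    /// candidate type, address, and network type, unless `foundation_override`
+    /// is set, in which case that value is returned verbatim.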
+ fn foundation(&self) -> String; + + /// A unique identifier for just this candidate + /// Unlike the foundation this is different for each candidate. + fn id(&self) -> String; + + /// A component is a piece of a data stream. + /// An example is one for RTP, and one for RTCP + fn component(&self) -> u16; + fn set_component(&self, c: u16); + + /// The last time this candidate received traffic + fn last_received(&self) -> SystemTime; + + /// The last time this candidate sent traffic + fn last_sent(&self) -> SystemTime; + + fn network_type(&self) -> NetworkType; + fn address(&self) -> String; + fn port(&self) -> u16; + + fn priority(&self) -> u32; + + /// A transport address related to candidate, + /// which is useful for diagnostics and other purposes. + fn related_address(&self) -> Option; + + fn candidate_type(&self) -> CandidateType; + fn tcp_type(&self) -> TcpType; + + fn marshal(&self) -> String; + + fn addr(&self) -> SocketAddr; + + async fn close(&self) -> Result<()>; + fn seen(&self, outbound: bool); + + async fn write_to(&self, raw: &[u8], dst: &(dyn Candidate + Send + Sync)) -> Result; + fn equal(&self, other: &dyn Candidate) -> bool; + fn set_ip(&self, ip: &IpAddr) -> Result<()>; + fn get_conn(&self) -> Option<&Arc>; + fn get_closed_ch(&self) -> Arc>>>; +} + +/// Represents the type of candidate `CandidateType` enum. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)] +pub enum CandidateType { + #[serde(rename = "unspecified")] + Unspecified, + #[serde(rename = "host")] + Host, + #[serde(rename = "srflx")] + ServerReflexive, + #[serde(rename = "prflx")] + PeerReflexive, + #[serde(rename = "relay")] + Relay, +} + +// String makes CandidateType printable +impl fmt::Display for CandidateType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + CandidateType::Host => "host", + CandidateType::ServerReflexive => "srflx", + CandidateType::PeerReflexive => "prflx", + CandidateType::Relay => "relay", + CandidateType::Unspecified => "Unknown candidate type", + }; + write!(f, "{s}") + } +} + +impl Default for CandidateType { + fn default() -> Self { + Self::Unspecified + } +} + +impl CandidateType { + /// Returns the preference weight of a `CandidateType`. + /// + /// 4.1.2.2. Guidelines for Choosing Type and Local Preferences + /// The RECOMMENDED values are 126 for host candidates, 100 + /// for server reflexive candidates, 110 for peer reflexive candidates, + /// and 0 for relayed candidates. + #[must_use] + pub const fn preference(self) -> u16 { + match self { + Self::Host => 126, + Self::PeerReflexive => 110, + Self::ServerReflexive => 100, + Self::Relay | CandidateType::Unspecified => 0, + } + } +} + +pub(crate) fn contains_candidate_type( + candidate_type: CandidateType, + candidate_type_list: &[CandidateType], +) -> bool { + if candidate_type_list.is_empty() { + return false; + } + for ct in candidate_type_list { + if *ct == candidate_type { + return true; + } + } + false +} + +/// Convey transport addresses related to the candidate, useful for diagnostics and other purposes. +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct CandidateRelatedAddress { + pub address: String, + pub port: u16, +} + +// String makes CandidateRelatedAddress printable +impl fmt::Display for CandidateRelatedAddress { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, " related {}:{}", self.address, self.port) + } +} + +/// Represent the ICE candidate pair state. 
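+///
+/// Stored on a `CandidatePair` as an `AtomicU8`; the `From<u8>` impl below maps any
+/// unknown value back to `Unspecified`.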
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)] +pub enum CandidatePairState { + #[serde(rename = "unspecified")] + Unspecified = 0, + + /// Means a check has not been performed for this pair. + #[serde(rename = "waiting")] + Waiting = 1, + + /// Means a check has been sent for this pair, but the transaction is in progress. + #[serde(rename = "in-progress")] + InProgress = 2, + + /// Means a check for this pair was already done and failed, either never producing any response + /// or producing an unrecoverable failure response. + #[serde(rename = "failed")] + Failed = 3, + + /// Means a check for this pair was already done and produced a successful result. + #[serde(rename = "succeeded")] + Succeeded = 4, +} + +impl From for CandidatePairState { + fn from(v: u8) -> Self { + match v { + 1 => Self::Waiting, + 2 => Self::InProgress, + 3 => Self::Failed, + 4 => Self::Succeeded, + _ => Self::Unspecified, + } + } +} + +impl Default for CandidatePairState { + fn default() -> Self { + Self::Unspecified + } +} + +impl fmt::Display for CandidatePairState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + Self::Waiting => "waiting", + Self::InProgress => "in-progress", + Self::Failed => "failed", + Self::Succeeded => "succeeded", + Self::Unspecified => "unspecified", + }; + + write!(f, "{s}") + } +} + +/// Represents a combination of a local and remote candidate. +pub struct CandidatePair { + pub(crate) ice_role_controlling: AtomicBool, + pub remote: Arc, + pub local: Arc, + pub(crate) binding_request_count: AtomicU16, + pub(crate) state: AtomicU8, // convert it to CandidatePairState, + pub(crate) nominated: AtomicBool, +} + +impl Default for CandidatePair { + fn default() -> Self { + Self { + ice_role_controlling: AtomicBool::new(false), + remote: Arc::new(CandidateBase::default()), + local: Arc::new(CandidateBase::default()), + state: AtomicU8::new(CandidatePairState::Waiting as u8), + binding_request_count: AtomicU16::new(0), + nominated: AtomicBool::new(false), + } + } +} + +impl fmt::Debug for CandidatePair { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "prio {} (local, prio {}) {} <-> {} (remote, prio {})", + self.priority(), + self.local.priority(), + self.local, + self.remote, + self.remote.priority() + ) + } +} + +impl fmt::Display for CandidatePair { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "prio {} (local, prio {}) {} <-> {} (remote, prio {})", + self.priority(), + self.local.priority(), + self.local, + self.remote, + self.remote.priority() + ) + } +} + +impl PartialEq for CandidatePair { + fn eq(&self, other: &Self) -> bool { + self.local.equal(&*other.local) && self.remote.equal(&*other.remote) + } +} + +impl CandidatePair { + #[must_use] + pub fn new( + local: Arc, + remote: Arc, + controlling: bool, + ) -> Self { + Self { + ice_role_controlling: AtomicBool::new(controlling), + remote, + local, + state: AtomicU8::new(CandidatePairState::Waiting as u8), + binding_request_count: AtomicU16::new(0), + nominated: AtomicBool::new(false), + } + } + + /// RFC 5245 - 5.7.2. Computing Pair Priority and Ordering Pairs + /// Let G be the priority for the candidate provided by the controlling + /// agent. Let D be the priority for the candidate provided by the + /// controlled agent. 
+ /// pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0) + pub fn priority(&self) -> u64 { + let (g, d) = if self.ice_role_controlling.load(Ordering::SeqCst) { + (self.local.priority(), self.remote.priority()) + } else { + (self.remote.priority(), self.local.priority()) + }; + + // 1<<32 overflows uint32; and if both g && d are + // maxUint32, this result would overflow uint64 + ((1 << 32_u64) - 1) * u64::from(std::cmp::min(g, d)) + + 2 * u64::from(std::cmp::max(g, d)) + + u64::from(g > d) + } + + pub async fn write(&self, b: &[u8]) -> Result { + self.local.write_to(b, &*self.remote).await + } +} diff --git a/reserved/ice/src/control/control_test.rs b/reserved/ice/src/control/control_test.rs new file mode 100644 index 0000000..480c045 --- /dev/null +++ b/reserved/ice/src/control/control_test.rs @@ -0,0 +1,168 @@ +use super::*; +use crate::error::Result; + +#[test] +fn test_controlled_get_from() -> Result<()> { + let mut m = Message::new(); + let mut c = AttrControlled(4321); + let result = c.get_from(&m); + if let Err(err) = result { + assert_eq!(stun::Error::ErrAttributeNotFound, err, "unexpected error"); + } else { + panic!("expected error, but got ok"); + } + + m.build(&[Box::new(BINDING_REQUEST), Box::new(c)])?; + + let mut m1 = Message::new(); + m1.write(&m.raw)?; + + let mut c1 = AttrControlled::default(); + c1.get_from(&m1)?; + + assert_eq!(c1, c, "not equal"); + + //"IncorrectSize" + { + let mut m3 = Message::new(); + m3.add(ATTR_ICE_CONTROLLED, &[0; 100]); + let mut c2 = AttrControlled::default(); + let result = c2.get_from(&m3); + if let Err(err) = result { + assert!(is_attr_size_invalid(&err), "should error"); + } else { + panic!("expected error, but got ok"); + } + } + + Ok(()) +} + +#[test] +fn test_controlling_get_from() -> Result<()> { + let mut m = Message::new(); + let mut c = AttrControlling(4321); + let result = c.get_from(&m); + if let Err(err) = result { + assert_eq!(stun::Error::ErrAttributeNotFound, err, "unexpected error"); + } else { + panic!("expected error, but got ok"); + } + + m.build(&[Box::new(BINDING_REQUEST), Box::new(c)])?; + + let mut m1 = Message::new(); + m1.write(&m.raw)?; + + let mut c1 = AttrControlling::default(); + c1.get_from(&m1)?; + + assert_eq!(c1, c, "not equal"); + + //"IncorrectSize" + { + let mut m3 = Message::new(); + m3.add(ATTR_ICE_CONTROLLING, &[0; 100]); + let mut c2 = AttrControlling::default(); + let result = c2.get_from(&m3); + if let Err(err) = result { + assert!(is_attr_size_invalid(&err), "should error"); + } else { + panic!("expected error, but got ok"); + } + } + + Ok(()) +} + +#[test] +fn test_control_get_from() -> Result<()> { + //"Blank" + { + let m = Message::new(); + let mut c = AttrControl::default(); + let result = c.get_from(&m); + if let Err(err) = result { + assert_eq!(stun::Error::ErrAttributeNotFound, err, "unexpected error"); + } else { + panic!("expected error, but got ok"); + } + } + //"Controlling" + { + let mut m = Message::new(); + let mut c = AttrControl::default(); + let result = c.get_from(&m); + if let Err(err) = result { + assert_eq!(stun::Error::ErrAttributeNotFound, err, "unexpected error"); + } else { + panic!("expected error, but got ok"); + } + + c.role = Role::Controlling; + c.tie_breaker = TieBreaker(4321); + + m.build(&[Box::new(BINDING_REQUEST), Box::new(c)])?; + + let mut m1 = Message::new(); + m1.write(&m.raw)?; + + let mut c1 = AttrControl::default(); + c1.get_from(&m1)?; + + assert_eq!(c1, c, "not equal"); + + //"IncorrectSize" + { + let mut m3 = Message::new(); + 
m3.add(ATTR_ICE_CONTROLLING, &[0; 100]); + let mut c2 = AttrControl::default(); + let result = c2.get_from(&m3); + if let Err(err) = result { + assert!(is_attr_size_invalid(&err), "should error"); + } else { + panic!("expected error, but got ok"); + } + } + } + + //"Controlled" + { + let mut m = Message::new(); + let mut c = AttrControl::default(); + let result = c.get_from(&m); + if let Err(err) = result { + assert_eq!(stun::Error::ErrAttributeNotFound, err, "unexpected error"); + } else { + panic!("expected error, but got ok"); + } + + c.role = Role::Controlled; + c.tie_breaker = TieBreaker(1234); + + m.build(&[Box::new(BINDING_REQUEST), Box::new(c)])?; + + let mut m1 = Message::new(); + m1.write(&m.raw)?; + + let mut c1 = AttrControl::default(); + c1.get_from(&m1)?; + + assert_eq!(c1, c, "not equal"); + + //"IncorrectSize" + { + let mut m3 = Message::new(); + m3.add(ATTR_ICE_CONTROLLING, &[0; 100]); + let mut c2 = AttrControl::default(); + let result = c2.get_from(&m3); + if let Err(err) = result { + assert!(is_attr_size_invalid(&err), "should error"); + } else { + panic!("expected error, but got ok"); + } + } + } + + Ok(()) +} diff --git a/reserved/ice/src/control/mod.rs b/reserved/ice/src/control/mod.rs new file mode 100644 index 0000000..a79e170 --- /dev/null +++ b/reserved/ice/src/control/mod.rs @@ -0,0 +1,143 @@ +#[cfg(test)] +mod control_test; + +use std::fmt; + +use stun::attributes::*; +use stun::checks::*; +use stun::message::*; + +/// Common helper for ICE-{CONTROLLED,CONTROLLING} and represents the so-called Tiebreaker number. +#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] +pub struct TieBreaker(pub u64); + +pub(crate) const TIE_BREAKER_SIZE: usize = 8; // 64 bit + +impl TieBreaker { + /// Adds Tiebreaker value to m as t attribute. + pub fn add_to_as(self, m: &mut Message, t: AttrType) -> Result<(), stun::Error> { + let mut v = vec![0; TIE_BREAKER_SIZE]; + v.copy_from_slice(&self.0.to_be_bytes()); + m.add(t, &v); + Ok(()) + } + + /// Decodes Tiebreaker value in message getting it as for t type. + pub fn get_from_as(&mut self, m: &Message, t: AttrType) -> Result<(), stun::Error> { + let v = m.get(t)?; + check_size(t, v.len(), TIE_BREAKER_SIZE)?; + self.0 = u64::from_be_bytes([v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]]); + Ok(()) + } +} +/// Represents ICE-CONTROLLED attribute. +#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] +pub struct AttrControlled(pub u64); + +impl Setter for AttrControlled { + /// Adds ICE-CONTROLLED to message. + fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { + TieBreaker(self.0).add_to_as(m, ATTR_ICE_CONTROLLED) + } +} + +impl Getter for AttrControlled { + /// Decodes ICE-CONTROLLED from message. + fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { + let mut t = TieBreaker::default(); + t.get_from_as(m, ATTR_ICE_CONTROLLED)?; + self.0 = t.0; + Ok(()) + } +} + +/// Represents ICE-CONTROLLING attribute. +#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] +pub struct AttrControlling(pub u64); + +impl Setter for AttrControlling { + // add_to adds ICE-CONTROLLING to message. + fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { + TieBreaker(self.0).add_to_as(m, ATTR_ICE_CONTROLLING) + } +} + +impl Getter for AttrControlling { + // get_from decodes ICE-CONTROLLING from message. 
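+    //
+    // Round-trip sketch (mirrors control_test.rs): attach the attribute to a
+    // BINDING_REQUEST, then decode it back from the built message.
+    //
+    //   let mut m = Message::new();
+    //   m.build(&[Box::new(BINDING_REQUEST), Box::new(AttrControlling(4321))])?;
+    //   let mut out = AttrControlling::default();
+    //   out.get_from(&m)?;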
+    fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> {
+        let mut t = TieBreaker::default();
+        t.get_from_as(m, ATTR_ICE_CONTROLLING)?;
+        self.0 = t.0;
+        Ok(())
+    }
+}
+
+/// Helper that wraps ICE-{CONTROLLED,CONTROLLING}.
+#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)]
+pub struct AttrControl {
+    role: Role,
+    tie_breaker: TieBreaker,
+}
+
+impl Setter for AttrControl {
+    // add_to adds ICE-CONTROLLED or ICE-CONTROLLING attribute depending on Role.
+    fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> {
+        if self.role == Role::Controlling {
+            self.tie_breaker.add_to_as(m, ATTR_ICE_CONTROLLING)
+        } else {
+            self.tie_breaker.add_to_as(m, ATTR_ICE_CONTROLLED)
+        }
+    }
+}
+
+impl Getter for AttrControl {
+    // get_from decodes Role and Tiebreaker value from message.
+    fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> {
+        if m.contains(ATTR_ICE_CONTROLLING) {
+            self.role = Role::Controlling;
+            return self.tie_breaker.get_from_as(m, ATTR_ICE_CONTROLLING);
+        }
+        if m.contains(ATTR_ICE_CONTROLLED) {
+            self.role = Role::Controlled;
+            return self.tie_breaker.get_from_as(m, ATTR_ICE_CONTROLLED);
+        }
+
+        Err(stun::Error::ErrAttributeNotFound)
+    }
+}
+
+/// Represents the ICE agent role, which can be controlling or controlled.
+#[derive(PartialEq, Eq, Copy, Clone, Debug)]
+pub enum Role {
+    Controlling,
+    Controlled,
+    Unspecified,
+}
+
+impl Default for Role {
+    fn default() -> Self {
+        Self::Controlling
+    }
+}
+
+impl From<&str> for Role {
+    fn from(raw: &str) -> Self {
+        match raw {
+            "controlling" => Self::Controlling,
+            "controlled" => Self::Controlled,
+            _ => Self::Unspecified,
+        }
+    }
+}
+
+impl fmt::Display for Role {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let s = match *self {
+            Self::Controlling => "controlling",
+            Self::Controlled => "controlled",
+            Self::Unspecified => "unspecified",
+        };
+        write!(f, "{s}")
+    }
+}
diff --git a/reserved/ice/src/error.rs b/reserved/ice/src/error.rs
new file mode 100644
index 0000000..a3f6ff8
--- /dev/null
+++ b/reserved/ice/src/error.rs
@@ -0,0 +1,238 @@
+use std::num::ParseIntError;
+use std::time::SystemTimeError;
+use std::{io, net};
+
+use thiserror::Error;
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[derive(Debug, Error, PartialEq)]
+#[non_exhaustive]
+pub enum Error {
+    /// Indicates an error with unknown info.
+    #[error("Unknown type")]
+    ErrUnknownType,
+
+    /// Indicates the scheme type could not be parsed.
+    #[error("unknown scheme type")]
+    ErrSchemeType,
+
+    /// Indicates query arguments are provided in a STUN URL.
+    #[error("queries not supported in stun address")]
+    ErrStunQuery,
+
+    /// Indicates a malformed query was provided.
+    #[error("invalid query")]
+    ErrInvalidQuery,
+
+    /// Indicates a malformed hostname was provided.
+    #[error("invalid hostname")]
+    ErrHost,
+
+    /// Indicates a malformed port was provided.
+    #[error("invalid port number")]
+    ErrPort,
+
+    /// Indicates the local username fragment has insufficient bits.
+    /// It has to be at least 24 bits long.
+    #[error("local username fragment is less than 24 bits long")]
+    ErrLocalUfragInsufficientBits,
+
+    /// Indicates the local password has insufficient bits.
+    /// It has to be at least 128 bits long.
+    #[error("local password is less than 128 bits long")]
+    ErrLocalPwdInsufficientBits,
+
+    /// Indicates an unsupported transport type was provided.
+    #[error("invalid transport protocol type")]
+    ErrProtoType,
+
+    /// Indicates the agent is closed.
+    #[error("the agent is closed")]
+    ErrClosed,
+
+    /// Indicates the agent does not have a valid candidate pair.
+    #[error("no candidate pairs available")]
+    ErrNoCandidatePairs,
+
+    /// Indicates the agent connection was canceled by the caller.
+    #[error("connecting canceled by caller")]
+    ErrCanceledByCaller,
+
+    /// Indicates the agent was started twice.
+    #[error("attempted to start agent twice")]
+    ErrMultipleStart,
+
+    /// Indicates the agent was started with an empty remote ufrag.
+    #[error("remote ufrag is empty")]
+    ErrRemoteUfragEmpty,
+
+    /// Indicates the agent was started with an empty remote pwd.
+    #[error("remote pwd is empty")]
+    ErrRemotePwdEmpty,
+
+    /// Indicates the agent was started without on_candidate.
+    #[error("no on_candidate provided")]
+    ErrNoOnCandidateHandler,
+
+    /// Indicates GatherCandidates has been called multiple times.
+    #[error("attempting to gather candidates during gathering state")]
+    ErrMultipleGatherAttempted,
+
+    /// Indicates the agent was given a TURN URL with an empty Username.
+    #[error("username is empty")]
+    ErrUsernameEmpty,
+
+    /// Indicates the agent was given a TURN URL with an empty Password.
+    #[error("password is empty")]
+    ErrPasswordEmpty,
+
+    /// Indicates we were unable to parse a candidate address.
+    #[error("failed to parse address")]
+    ErrAddressParseFailed,
+
+    /// Indicates that non-host candidates were selected for a lite agent.
+    #[error("lite agents must only use host candidates")]
+    ErrLiteUsingNonHostCandidates,
+
+    /// Indicates that one or more URLs were provided to the agent but no host candidate required them.
+    #[error("agent does not need URL with selected candidate types")]
+    ErrUselessUrlsProvided,
+
+    /// Indicates that the specified NAT1To1IPCandidateType is unsupported.
+    #[error("unsupported 1:1 NAT IP candidate type")]
+    ErrUnsupportedNat1to1IpCandidateType,
+
+    /// Indicates that the given 1:1 NAT IP mapping is invalid.
+    #[error("invalid 1:1 NAT IP mapping")]
+    ErrInvalidNat1to1IpMapping,
+
+    /// Indicates that the external mapped IP was not found in the NAT1To1IPMapping.
+    #[error("external mapped IP not found")]
+    ErrExternalMappedIpNotFound,
+
+    /// Indicates that the mDNS gathering cannot be used along with 1:1 NAT IP mapping for host
+    /// candidate.
+    #[error("mDNS gathering cannot be used with 1:1 NAT IP mapping for host candidate")]
+    ErrMulticastDnsWithNat1to1IpMapping,
+
+    /// Indicates that 1:1 NAT IP mapping for host candidate is requested, but the host candidate
+    /// type is disabled.
+    #[error("1:1 NAT IP mapping for host candidate ineffective")]
+    ErrIneffectiveNat1to1IpMappingHost,
+
+    /// Indicates that 1:1 NAT IP mapping for srflx candidate is requested, but the srflx candidate
+    /// type is disabled.
+    #[error("1:1 NAT IP mapping for srflx candidate ineffective")]
+    ErrIneffectiveNat1to1IpMappingSrflx,
+
+    /// Indicates an invalid MulticastDNSHostName.
+    #[error("invalid mDNS HostName, must end with .local and can only contain a single '.'")]
+    ErrInvalidMulticastDnshostName,
+
+    /// Indicates Restart was called when Agent is in GatheringStateGathering.
+    #[error("ICE Agent can not be restarted when gathering")]
+    ErrRestartWhenGathering,
+
+    /// Indicates a run operation was canceled by its individual done.
+    #[error("run was canceled by done")]
+    ErrRunCanceled,
+
+    /// Indicates TCPMux is not initialized and that invalidTCPMux is used.
+    #[error("TCPMux is not initialized")]
+    ErrTcpMuxNotInitialized,
+
+    /// Indicates we already have a connection with the same remote address.
+ #[error("conn with same remote addr already exists")] + ErrTcpRemoteAddrAlreadyExists, + + #[error("failed to send packet")] + ErrSendPacket, + #[error("attribute not long enough to be ICE candidate")] + ErrAttributeTooShortIceCandidate, + #[error("could not parse component")] + ErrParseComponent, + #[error("could not parse priority")] + ErrParsePriority, + #[error("could not parse port")] + ErrParsePort, + #[error("could not parse related addresses")] + ErrParseRelatedAddr, + #[error("could not parse type")] + ErrParseType, + #[error("unknown candidate type")] + ErrUnknownCandidateType, + #[error("failed to get XOR-MAPPED-ADDRESS response")] + ErrGetXorMappedAddrResponse, + #[error("connection with same remote address already exists")] + ErrConnectionAddrAlreadyExist, + #[error("error reading streaming packet")] + ErrReadingStreamingPacket, + #[error("error writing to")] + ErrWriting, + #[error("error closing connection")] + ErrClosingConnection, + #[error("unable to determine networkType")] + ErrDetermineNetworkType, + #[error("missing protocol scheme")] + ErrMissingProtocolScheme, + #[error("too many colons in address")] + ErrTooManyColonsAddr, + #[error("unexpected error trying to read")] + ErrRead, + #[error("unknown role")] + ErrUnknownRole, + #[error("username mismatch")] + ErrMismatchUsername, + #[error("the ICE conn can't write STUN messages")] + ErrIceWriteStunMessage, + #[error("invalid url")] + ErrInvalidUrl, + #[error("relative URL without a base")] + ErrUrlParse, + #[error("Candidate IP could not be found")] + ErrCandidateIpNotFound, + + #[error("parse int: {0}")] + ParseInt(#[from] ParseIntError), + #[error("parse addr: {0}")] + ParseIp(#[from] net::AddrParseError), + #[error("{0}")] + Io(#[source] IoError), + #[error("{0}")] + Util(#[from] util::Error), + #[error("{0}")] + Stun(#[from] stun::Error), + #[error("{0}")] + ParseUrl(#[from] url::ParseError), + #[error("{0}")] + Mdns(#[from] mdns::Error), + #[error("{0}")] + Turn(#[from] turn::Error), + + #[error("{0}")] + Other(String), +} + +#[derive(Debug, Error)] +#[error("io error: {0}")] +pub struct IoError(#[from] pub io::Error); + +// Workaround for wanting PartialEq for io::Error. 
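+// Two `IoError`s are treated as equal when their `io::ErrorKind`s match, which is the
+// closest approximation available since `io::Error` itself does not implement `PartialEq`.
+// For example, two separately constructed `ErrorKind::TimedOut` errors compare equal here
+// even though they are distinct values.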
+impl PartialEq for IoError { + fn eq(&self, other: &Self) -> bool { + self.0.kind() == other.0.kind() + } +} + +impl From for Error { + fn from(e: io::Error) -> Self { + Error::Io(IoError(e)) + } +} + +impl From for Error { + fn from(e: SystemTimeError) -> Self { + Error::Other(e.to_string()) + } +} diff --git a/reserved/ice/src/external_ip_mapper/external_ip_mapper_test.rs b/reserved/ice/src/external_ip_mapper/external_ip_mapper_test.rs new file mode 100644 index 0000000..cb04d08 --- /dev/null +++ b/reserved/ice/src/external_ip_mapper/external_ip_mapper_test.rs @@ -0,0 +1,251 @@ +use super::*; + +#[test] +fn test_external_ip_mapper_validate_ip_string() -> Result<()> { + let ip = validate_ip_string("1.2.3.4")?; + assert!(ip.is_ipv4(), "should be true"); + assert_eq!("1.2.3.4", ip.to_string(), "should be true"); + + let ip = validate_ip_string("2601:4567::5678")?; + assert!(!ip.is_ipv4(), "should be false"); + assert_eq!("2601:4567::5678", ip.to_string(), "should be true"); + + let result = validate_ip_string("bad.6.6.6"); + assert!(result.is_err(), "should fail"); + + Ok(()) +} + +#[test] +fn test_external_ip_mapper_new_external_ip_mapper() -> Result<()> { + // ips being empty should succeed but mapper will still be nil + let m = ExternalIpMapper::new(CandidateType::Unspecified, &[])?; + assert!(m.is_none(), "should be none"); + + // IPv4 with no explicit local IP, defaults to CandidateTypeHost + let m = ExternalIpMapper::new(CandidateType::Unspecified, &["1.2.3.4".to_owned()])?.unwrap(); + assert_eq!(m.candidate_type, CandidateType::Host, "should match"); + assert!(m.ipv4_mapping.ip_sole.is_some()); + assert!(m.ipv6_mapping.ip_sole.is_none()); + assert_eq!(m.ipv4_mapping.ip_map.len(), 0, "should match"); + assert_eq!(m.ipv6_mapping.ip_map.len(), 0, "should match"); + + // IPv4 with no explicit local IP, using CandidateTypeServerReflexive + let m = + ExternalIpMapper::new(CandidateType::ServerReflexive, &["1.2.3.4".to_owned()])?.unwrap(); + assert_eq!( + CandidateType::ServerReflexive, + m.candidate_type, + "should match" + ); + assert!(m.ipv4_mapping.ip_sole.is_some()); + assert!(m.ipv6_mapping.ip_sole.is_none()); + assert_eq!(m.ipv4_mapping.ip_map.len(), 0, "should match"); + assert_eq!(m.ipv6_mapping.ip_map.len(), 0, "should match"); + + // IPv4 with no explicit local IP, defaults to CandidateTypeHost + let m = ExternalIpMapper::new(CandidateType::Unspecified, &["2601:4567::5678".to_owned()])? + .unwrap(); + assert_eq!(m.candidate_type, CandidateType::Host, "should match"); + assert!(m.ipv4_mapping.ip_sole.is_none()); + assert!(m.ipv6_mapping.ip_sole.is_some()); + assert_eq!(m.ipv4_mapping.ip_map.len(), 0, "should match"); + assert_eq!(m.ipv6_mapping.ip_map.len(), 0, "should match"); + + // IPv4 and IPv6 in the mix + let m = ExternalIpMapper::new( + CandidateType::Unspecified, + &["1.2.3.4".to_owned(), "2601:4567::5678".to_owned()], + )? 
+    .unwrap();
+    assert_eq!(m.candidate_type, CandidateType::Host, "should match");
+    assert!(m.ipv4_mapping.ip_sole.is_some());
+    assert!(m.ipv6_mapping.ip_sole.is_some());
+    assert_eq!(m.ipv4_mapping.ip_map.len(), 0, "should match");
+    assert_eq!(m.ipv6_mapping.ip_map.len(), 0, "should match");
+
+    // Unsupported candidate type - CandidateTypePeerReflexive
+    let result = ExternalIpMapper::new(CandidateType::PeerReflexive, &["1.2.3.4".to_owned()]);
+    assert!(result.is_err(), "should fail");
+
+    // Unsupported candidate type - CandidateTypeRelay
+    let result = ExternalIpMapper::new(CandidateType::Relay, &["1.2.3.4".to_owned()]);
+    assert!(result.is_err(), "should fail");
+
+    // Cannot duplicate mapping IPv4 family
+    let result = ExternalIpMapper::new(
+        CandidateType::ServerReflexive,
+        &["1.2.3.4".to_owned(), "5.6.7.8".to_owned()],
+    );
+    assert!(result.is_err(), "should fail");
+
+    // Cannot duplicate mapping IPv6 family
+    let result = ExternalIpMapper::new(
+        CandidateType::ServerReflexive,
+        &["2201::1".to_owned(), "2201::0002".to_owned()],
+    );
+    assert!(result.is_err(), "should fail");
+
+    // Invalid external IP string
+    let result = ExternalIpMapper::new(CandidateType::ServerReflexive, &["bad.2.3.4".to_owned()]);
+    assert!(result.is_err(), "should fail");
+
+    // Invalid local IP string
+    let result = ExternalIpMapper::new(
+        CandidateType::ServerReflexive,
+        &["1.2.3.4/10.0.0.bad".to_owned()],
+    );
+    assert!(result.is_err(), "should fail");
+
+    Ok(())
+}
+
+#[test]
+fn test_external_ip_mapper_new_external_ip_mapper_with_explicit_local_ip() -> Result<()> {
+    // IPv4 with explicit local IP, defaults to CandidateTypeHost
+    let m = ExternalIpMapper::new(CandidateType::Unspecified, &["1.2.3.4/10.0.0.1".to_owned()])?
+        .unwrap();
+    assert_eq!(m.candidate_type, CandidateType::Host, "should match");
+    assert!(m.ipv4_mapping.ip_sole.is_none());
+    assert!(m.ipv6_mapping.ip_sole.is_none());
+    assert_eq!(m.ipv4_mapping.ip_map.len(), 1, "should match");
+    assert_eq!(m.ipv6_mapping.ip_map.len(), 0, "should match");
+
+    // Cannot assign two ext IPs for one local IPv4
+    let result = ExternalIpMapper::new(
+        CandidateType::Unspecified,
+        &["1.2.3.4/10.0.0.1".to_owned(), "1.2.3.5/10.0.0.1".to_owned()],
+    );
+    assert!(result.is_err(), "should fail");
+
+    // Cannot assign two ext IPs for one local IPv6
+    let result = ExternalIpMapper::new(
+        CandidateType::Unspecified,
+        &[
+            "2200::1/fe80::1".to_owned(),
+            "2200::0002/fe80::1".to_owned(),
+        ],
+    );
+    assert!(result.is_err(), "should fail");
+
+    // Cannot mix different IP family in a pair (1)
+    let result =
+        ExternalIpMapper::new(CandidateType::Unspecified, &["2200::1/10.0.0.1".to_owned()]);
+    assert!(result.is_err(), "should fail");
+
+    // Cannot mix different IP family in a pair (2)
+    let result = ExternalIpMapper::new(CandidateType::Unspecified, &["1.2.3.4/fe80::1".to_owned()]);
+    assert!(result.is_err(), "should fail");
+
+    // Invalid pair
+    let result = ExternalIpMapper::new(
+        CandidateType::Unspecified,
+        &["1.2.3.4/192.168.0.2/10.0.0.1".to_owned()],
+    );
+    assert!(result.is_err(), "should fail");
+
+    Ok(())
+}
+
+#[test]
+fn test_external_ip_mapper_new_external_ip_mapper_with_implicit_local_ip() -> Result<()> {
+    // Mixing implicit and explicit local IPs not allowed
+    let result = ExternalIpMapper::new(
+        CandidateType::Unspecified,
+        &["1.2.3.4".to_owned(), "1.2.3.5/10.0.0.1".to_owned()],
+    );
+    assert!(result.is_err(), "should fail");
+
+    // Mixing implicit and explicit local IPs not allowed
+    let result = ExternalIpMapper::new(
+        CandidateType::Unspecified,
+        &["1.2.3.5/10.0.0.1".to_owned(), "1.2.3.4".to_owned()],
+    );
+    assert!(result.is_err(), "should fail");
+
+    Ok(())
+}
+
+#[test]
+fn test_external_ip_mapper_find_external_ip_without_explicit_local_ip() -> Result<()> {
+    // IPv4 and IPv6 with no explicit local IP, defaults to CandidateTypeHost
+    let m = ExternalIpMapper::new(
+        CandidateType::Unspecified,
+        &["1.2.3.4".to_owned(), "2200::1".to_owned()],
+    )?
+    .unwrap();
+    assert!(m.ipv4_mapping.ip_sole.is_some());
+    assert!(m.ipv6_mapping.ip_sole.is_some());
+
+    // find external IPv4
+    let ext_ip = m.find_external_ip("10.0.0.1")?;
+    assert_eq!(ext_ip.to_string(), "1.2.3.4", "should match");
+
+    // find external IPv6
+    let ext_ip = m.find_external_ip("fe80::0001")?; // use '0001' instead of '1' on purpose
+    assert_eq!(ext_ip.to_string(), "2200::1", "should match");
+
+    // Bad local IP string
+    let result = m.find_external_ip("really.bad");
+    assert!(result.is_err(), "should fail");
+
+    Ok(())
+}
+
+#[test]
+fn test_external_ip_mapper_find_external_ip_with_explicit_local_ip() -> Result<()> {
+    // IPv4 and IPv6 with explicit local IPs, defaults to CandidateTypeHost
+    let m = ExternalIpMapper::new(
+        CandidateType::Unspecified,
+        &[
+            "1.2.3.4/10.0.0.1".to_owned(),
+            "1.2.3.5/10.0.0.2".to_owned(),
+            "2200::1/fe80::1".to_owned(),
+            "2200::2/fe80::2".to_owned(),
+        ],
+    )?
+    .unwrap();
+
+    // find external IPv4
+    let ext_ip = m.find_external_ip("10.0.0.1")?;
+    assert_eq!(ext_ip.to_string(), "1.2.3.4", "should match");
+
+    let ext_ip = m.find_external_ip("10.0.0.2")?;
+    assert_eq!(ext_ip.to_string(), "1.2.3.5", "should match");
+
+    let result = m.find_external_ip("10.0.0.3");
+    assert!(result.is_err(), "should fail");
+
+    // find external IPv6
+    let ext_ip = m.find_external_ip("fe80::0001")?; // use '0001' instead of '1' on purpose
+    assert_eq!(ext_ip.to_string(), "2200::1", "should match");
+
+    let ext_ip = m.find_external_ip("fe80::0002")?; // use '0002' instead of '2' on purpose
+    assert_eq!(ext_ip.to_string(), "2200::2", "should match");
+
+    let result = m.find_external_ip("fe80::3");
+    assert!(result.is_err(), "should fail");
+
+    // Bad local IP string
+    let result = m.find_external_ip("really.bad");
+    assert!(result.is_err(), "should fail");
+
+    Ok(())
+}
+
+#[test]
+fn test_external_ip_mapper_find_external_ip_with_empty_map() -> Result<()> {
+    let m = ExternalIpMapper::new(CandidateType::Unspecified, &["1.2.3.4".to_owned()])?.unwrap();
+
+    // attempt to find IPv6 that does not exist in the map
+    let result = m.find_external_ip("fe80::1");
+    assert!(result.is_err(), "should fail");
+
+    let m = ExternalIpMapper::new(CandidateType::Unspecified, &["2200::1".to_owned()])?.unwrap();
+
+    // attempt to find IPv4 that does not exist in the map
+    let result = m.find_external_ip("10.0.0.1");
+    assert!(result.is_err(), "should fail");
+
+    Ok(())
+}
diff --git a/reserved/ice/src/external_ip_mapper/mod.rs b/reserved/ice/src/external_ip_mapper/mod.rs
new file mode 100644
index 0000000..0d968b8
--- /dev/null
+++ b/reserved/ice/src/external_ip_mapper/mod.rs
@@ -0,0 +1,133 @@
+#[cfg(test)]
+mod external_ip_mapper_test;
+
+use std::collections::HashMap;
+use std::net::IpAddr;
+
+use crate::candidate::*;
+use crate::error::*;
+
+pub(crate) fn validate_ip_string(ip_str: &str) -> Result<IpAddr> {
+    match ip_str.parse() {
+        Ok(ip) => Ok(ip),
+        Err(_) => Err(Error::ErrInvalidNat1to1IpMapping),
+    }
+}
+
+/// Holds the mapping of local to external IP addresses for a particular IP family.
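+///
+/// Two configurations are possible (a rough illustration of the parsing done in
+/// `ExternalIpMapper::new` below): a bare external IP such as `"1.2.3.4"` becomes the
+/// sole mapping (`ip_sole`) applied to every local address of that family, while an
+/// `"external/local"` pair such as `"1.2.3.4/10.0.0.1"` adds a per-local-address entry
+/// to `ip_map`. The two forms cannot be mixed within one family.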
+#[derive(Default, PartialEq, Debug)] +pub(crate) struct IpMapping { + ip_sole: Option, // when non-nil, this is the sole external IP for one local IP assumed + ip_map: HashMap, // local-to-external IP mapping (k: local, v: external) +} + +impl IpMapping { + pub(crate) fn set_sole_ip(&mut self, ip: IpAddr) -> Result<()> { + if self.ip_sole.is_some() || !self.ip_map.is_empty() { + return Err(Error::ErrInvalidNat1to1IpMapping); + } + + self.ip_sole = Some(ip); + + Ok(()) + } + + pub(crate) fn add_ip_mapping(&mut self, loc_ip: IpAddr, ext_ip: IpAddr) -> Result<()> { + if self.ip_sole.is_some() { + return Err(Error::ErrInvalidNat1to1IpMapping); + } + + let loc_ip_str = loc_ip.to_string(); + + // check if dup of local IP + if self.ip_map.contains_key(&loc_ip_str) { + return Err(Error::ErrInvalidNat1to1IpMapping); + } + + self.ip_map.insert(loc_ip_str, ext_ip); + + Ok(()) + } + + pub(crate) fn find_external_ip(&self, loc_ip: IpAddr) -> Result { + if let Some(ip_sole) = &self.ip_sole { + return Ok(*ip_sole); + } + + self.ip_map.get(&loc_ip.to_string()).map_or_else( + || Err(Error::ErrExternalMappedIpNotFound), + |ext_ip| Ok(*ext_ip), + ) + } +} + +#[derive(Default)] +pub(crate) struct ExternalIpMapper { + pub(crate) ipv4_mapping: IpMapping, + pub(crate) ipv6_mapping: IpMapping, + pub(crate) candidate_type: CandidateType, +} + +impl ExternalIpMapper { + pub(crate) fn new(mut candidate_type: CandidateType, ips: &[String]) -> Result> { + if ips.is_empty() { + return Ok(None); + } + if candidate_type == CandidateType::Unspecified { + candidate_type = CandidateType::Host; // defaults to host + } else if candidate_type != CandidateType::Host + && candidate_type != CandidateType::ServerReflexive + { + return Err(Error::ErrUnsupportedNat1to1IpCandidateType); + } + + let mut m = Self { + ipv4_mapping: IpMapping::default(), + ipv6_mapping: IpMapping::default(), + candidate_type, + }; + + for ext_ip_str in ips { + let ip_pair: Vec<&str> = ext_ip_str.split('/').collect(); + if ip_pair.is_empty() || ip_pair.len() > 2 { + return Err(Error::ErrInvalidNat1to1IpMapping); + } + + let ext_ip = validate_ip_string(ip_pair[0])?; + if ip_pair.len() == 1 { + if ext_ip.is_ipv4() { + m.ipv4_mapping.set_sole_ip(ext_ip)?; + } else { + m.ipv6_mapping.set_sole_ip(ext_ip)?; + } + } else { + let loc_ip = validate_ip_string(ip_pair[1])?; + if ext_ip.is_ipv4() { + if !loc_ip.is_ipv4() { + return Err(Error::ErrInvalidNat1to1IpMapping); + } + + m.ipv4_mapping.add_ip_mapping(loc_ip, ext_ip)?; + } else { + if loc_ip.is_ipv4() { + return Err(Error::ErrInvalidNat1to1IpMapping); + } + + m.ipv6_mapping.add_ip_mapping(loc_ip, ext_ip)?; + } + } + } + + Ok(Some(m)) + } + + pub(crate) fn find_external_ip(&self, local_ip_str: &str) -> Result { + let loc_ip = validate_ip_string(local_ip_str)?; + + if loc_ip.is_ipv4() { + self.ipv4_mapping.find_external_ip(loc_ip) + } else { + self.ipv6_mapping.find_external_ip(loc_ip) + } + } +} diff --git a/reserved/ice/src/lib.rs b/reserved/ice/src/lib.rs new file mode 100644 index 0000000..b2e0af1 --- /dev/null +++ b/reserved/ice/src/lib.rs @@ -0,0 +1,22 @@ +#![warn(rust_2018_idioms)] +#![allow(dead_code)] + +pub mod agent; +pub mod candidate; +pub mod control; +mod error; +pub mod external_ip_mapper; +pub mod mdns; +pub mod network_type; +pub mod priority; +pub mod rand; +pub mod state; +pub mod stats; +pub mod tcp_type; +pub mod udp_mux; +pub mod udp_network; +pub mod url; +pub mod use_candidate; +pub mod util; + +pub use error::Error; diff --git a/reserved/ice/src/mdns/mdns_test.rs 
b/reserved/ice/src/mdns/mdns_test.rs new file mode 100644 index 0000000..6040103 --- /dev/null +++ b/reserved/ice/src/mdns/mdns_test.rs @@ -0,0 +1,151 @@ +use regex::Regex; +use tokio::sync::{mpsc, Mutex}; + +use super::*; +use crate::agent::agent_config::*; +use crate::agent::agent_vnet_test::*; +use crate::agent::*; +use crate::candidate::*; +use crate::error::Error; +use crate::network_type::*; + +#[tokio::test] +// This test is disabled on Windows for now because it gets stuck and never finishes. +// This does not seem to have happened due to a code change. It started happening with +// `ce55c3a066ab461c3e74f0d5ac6f1209205e79bc` but was verified as happening on +// `92cc698a3dc6da459f3bf3789fd046c2dffdf107` too. +#[cfg(not(windows))] +async fn test_multicast_dns_only_connection() -> Result<()> { + let cfg0 = AgentConfig { + network_types: vec![NetworkType::Udp4], + candidate_types: vec![CandidateType::Host], + multicast_dns_mode: MulticastDnsMode::QueryAndGather, + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + let (a_notifier, mut a_connected) = on_connected(); + a_agent.on_connection_state_change(a_notifier); + + let cfg1 = AgentConfig { + network_types: vec![NetworkType::Udp4], + candidate_types: vec![CandidateType::Host], + multicast_dns_mode: MulticastDnsMode::QueryAndGather, + ..Default::default() + }; + + let b_agent = Arc::new(Agent::new(cfg1).await?); + let (b_notifier, mut b_connected) = on_connected(); + b_agent.on_connection_state_change(b_notifier); + + connect_with_vnet(&a_agent, &b_agent).await?; + let _ = a_connected.recv().await; + let _ = b_connected.recv().await; + + a_agent.close().await?; + b_agent.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_multicast_dns_mixed_connection() -> Result<()> { + let cfg0 = AgentConfig { + network_types: vec![NetworkType::Udp4], + candidate_types: vec![CandidateType::Host], + multicast_dns_mode: MulticastDnsMode::QueryAndGather, + ..Default::default() + }; + + let a_agent = Arc::new(Agent::new(cfg0).await?); + let (a_notifier, mut a_connected) = on_connected(); + a_agent.on_connection_state_change(a_notifier); + + let cfg1 = AgentConfig { + network_types: vec![NetworkType::Udp4], + candidate_types: vec![CandidateType::Host], + multicast_dns_mode: MulticastDnsMode::QueryOnly, + ..Default::default() + }; + + let b_agent = Arc::new(Agent::new(cfg1).await?); + let (b_notifier, mut b_connected) = on_connected(); + b_agent.on_connection_state_change(b_notifier); + + connect_with_vnet(&a_agent, &b_agent).await?; + let _ = a_connected.recv().await; + let _ = b_connected.recv().await; + + a_agent.close().await?; + b_agent.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_multicast_dns_static_host_name() -> Result<()> { + let cfg0 = AgentConfig { + network_types: vec![NetworkType::Udp4], + candidate_types: vec![CandidateType::Host], + multicast_dns_mode: MulticastDnsMode::QueryAndGather, + multicast_dns_host_name: "invalidHostName".to_owned(), + ..Default::default() + }; + if let Err(err) = Agent::new(cfg0).await { + assert_eq!(err, Error::ErrInvalidMulticastDnshostName); + } else { + panic!("expected error, but got ok"); + } + + let cfg1 = AgentConfig { + network_types: vec![NetworkType::Udp4], + candidate_types: vec![CandidateType::Host], + multicast_dns_mode: MulticastDnsMode::QueryAndGather, + multicast_dns_host_name: "validName.local".to_owned(), + ..Default::default() + }; + + let a = Agent::new(cfg1).await?; + + let (done_tx, mut done_rx) = mpsc::channel::<()>(1); + let 
done_tx = Arc::new(Mutex::new(Some(done_tx))); + a.on_candidate(Box::new( + move |c: Option>| { + let done_tx_clone = Arc::clone(&done_tx); + Box::pin(async move { + if c.is_none() { + let mut tx = done_tx_clone.lock().await; + tx.take(); + } + }) + }, + )); + + a.gather_candidates()?; + + log::debug!("wait for gathering is done..."); + let _ = done_rx.recv().await; + log::debug!("gathering is done"); + + Ok(()) +} + +#[test] +fn test_generate_multicast_dnsname() -> Result<()> { + let name = generate_multicast_dns_name(); + + let re = Regex::new( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-4[0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}.local+$", + ); + + if let Ok(re) = re { + assert!( + re.is_match(&name), + "mDNS name must be UUID v4 + \".local\" suffix, got {name}" + ); + } else { + panic!("expected ok, but got err"); + } + + Ok(()) +} diff --git a/reserved/ice/src/mdns/mod.rs b/reserved/ice/src/mdns/mod.rs new file mode 100644 index 0000000..981d58f --- /dev/null +++ b/reserved/ice/src/mdns/mod.rs @@ -0,0 +1,71 @@ +#[cfg(test)] +mod mdns_test; + +use std::net::SocketAddr; +use std::str::FromStr; +use std::sync::Arc; + +use mdns::config::*; +use mdns::conn::*; +use uuid::Uuid; + +use crate::error::Result; + +/// Represents the different Multicast modes that ICE can run. +#[derive(PartialEq, Eq, Debug, Copy, Clone)] +pub enum MulticastDnsMode { + /// Means remote mDNS candidates will be discarded, and local host candidates will use IPs. + Disabled, + + /// Means remote mDNS candidates will be accepted, and local host candidates will use IPs. + QueryOnly, + + /// Means remote mDNS candidates will be accepted, and local host candidates will use mDNS. + QueryAndGather, +} + +impl Default for MulticastDnsMode { + fn default() -> Self { + Self::QueryOnly + } +} + +pub(crate) fn generate_multicast_dns_name() -> String { + // https://tools.ietf.org/id/draft-ietf-rtcweb-mdns-ice-candidates-02.html#gathering + // The unique name MUST consist of a version 4 UUID as defined in [RFC4122], followed by “.local”. + let u = Uuid::new_v4(); + format!("{u}.local") +} + +pub(crate) fn create_multicast_dns( + mdns_mode: MulticastDnsMode, + mdns_name: &str, + dest_addr: &str, +) -> Result>> { + let local_names = match mdns_mode { + MulticastDnsMode::QueryOnly => vec![], + MulticastDnsMode::QueryAndGather => vec![mdns_name.to_owned()], + MulticastDnsMode::Disabled => return Ok(None), + }; + + let addr = if dest_addr.is_empty() { + //TODO: why DEFAULT_DEST_ADDR doesn't work on Mac/Win? + if cfg!(target_os = "linux") { + SocketAddr::from_str(DEFAULT_DEST_ADDR)? + } else { + SocketAddr::from_str("0.0.0.0:5353")? + } + } else { + SocketAddr::from_str(dest_addr)? + }; + log::info!("mDNS is using {} as dest_addr", addr); + + let conn = DnsConn::server( + addr, + Config { + local_names, + ..Config::default() + }, + )?; + Ok(Some(Arc::new(conn))) +} diff --git a/reserved/ice/src/network_type/mod.rs b/reserved/ice/src/network_type/mod.rs new file mode 100644 index 0000000..fcd50f9 --- /dev/null +++ b/reserved/ice/src/network_type/mod.rs @@ -0,0 +1,148 @@ +#[cfg(test)] +mod network_type_test; + +use std::fmt; +use std::net::IpAddr; + +use serde::{Deserialize, Serialize}; + +use crate::error::*; + +pub(crate) const UDP: &str = "udp"; +pub(crate) const TCP: &str = "tcp"; + +#[must_use] +pub fn supported_network_types() -> Vec { + vec![ + NetworkType::Udp4, + NetworkType::Udp6, + //NetworkType::TCP4, + //NetworkType::TCP6, + ] +} + +/// Represents the type of network. 
+#[derive(PartialEq, Debug, Copy, Clone, Eq, Hash, Serialize, Deserialize)] +pub enum NetworkType { + #[serde(rename = "unspecified")] + Unspecified, + + /// Indicates UDP over IPv4. + #[serde(rename = "udp4")] + Udp4, + + /// Indicates UDP over IPv6. + #[serde(rename = "udp6")] + Udp6, + + /// Indicates TCP over IPv4. + #[serde(rename = "tcp4")] + Tcp4, + + /// Indicates TCP over IPv6. + #[serde(rename = "tcp6")] + Tcp6, +} + +impl From for NetworkType { + fn from(v: u8) -> Self { + match v { + 1 => Self::Udp4, + 2 => Self::Udp6, + 3 => Self::Tcp4, + 4 => Self::Tcp6, + _ => Self::Unspecified, + } + } +} + +impl fmt::Display for NetworkType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + Self::Udp4 => "udp4", + Self::Udp6 => "udp6", + Self::Tcp4 => "tcp4", + Self::Tcp6 => "tcp6", + Self::Unspecified => "unspecified", + }; + write!(f, "{s}") + } +} + +impl Default for NetworkType { + fn default() -> Self { + Self::Unspecified + } +} + +impl NetworkType { + /// Returns true when network is UDP4 or UDP6. + #[must_use] + pub fn is_udp(self) -> bool { + self == Self::Udp4 || self == Self::Udp6 + } + + /// Returns true when network is TCP4 or TCP6. + #[must_use] + pub fn is_tcp(self) -> bool { + self == Self::Tcp4 || self == Self::Tcp6 + } + + /// Returns the short network description. + #[must_use] + pub fn network_short(self) -> String { + match self { + Self::Udp4 | Self::Udp6 => UDP.to_owned(), + Self::Tcp4 | Self::Tcp6 => TCP.to_owned(), + Self::Unspecified => "Unspecified".to_owned(), + } + } + + /// Returns true if the network is reliable. + #[must_use] + pub const fn is_reliable(self) -> bool { + match self { + Self::Tcp4 | Self::Tcp6 => true, + Self::Udp4 | Self::Udp6 | Self::Unspecified => false, + } + } + + /// Returns whether the network type is IPv4 or not. + #[must_use] + pub const fn is_ipv4(self) -> bool { + match self { + Self::Udp4 | Self::Tcp4 => true, + Self::Udp6 | Self::Tcp6 | Self::Unspecified => false, + } + } + + /// Returns whether the network type is IPv6 or not. + #[must_use] + pub const fn is_ipv6(self) -> bool { + match self { + Self::Udp6 | Self::Tcp6 => true, + Self::Udp4 | Self::Tcp4 | Self::Unspecified => false, + } + } +} + +/// Determines the type of network based on the short network string and an IP address. 
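+/// For example (mirroring the cases in `network_type_test.rs`), `("udp", 192.168.0.1)`
+/// resolves to `NetworkType::Udp4` and `("TCP", fe80::1)` resolves to `NetworkType::Tcp6`;
+/// anything that does not start with `udp`/`tcp` yields `Error::ErrDetermineNetworkType`.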
+pub(crate) fn determine_network_type(network: &str, ip: &IpAddr) -> Result { + let ipv4 = ip.is_ipv4(); + let net = network.to_lowercase(); + if net.starts_with(UDP) { + if ipv4 { + Ok(NetworkType::Udp4) + } else { + Ok(NetworkType::Udp6) + } + } else if net.starts_with(TCP) { + if ipv4 { + Ok(NetworkType::Tcp4) + } else { + Ok(NetworkType::Tcp6) + } + } else { + Err(Error::ErrDetermineNetworkType) + } +} diff --git a/reserved/ice/src/network_type/network_type_test.rs b/reserved/ice/src/network_type/network_type_test.rs new file mode 100644 index 0000000..fa2a91d --- /dev/null +++ b/reserved/ice/src/network_type/network_type_test.rs @@ -0,0 +1,95 @@ +use super::*; +use crate::error::Result; + +#[test] +fn test_network_type_parsing_success() -> Result<()> { + let ipv4: IpAddr = "192.168.0.1".parse().unwrap(); + let ipv6: IpAddr = "fe80::a3:6ff:fec4:5454".parse().unwrap(); + + let tests = vec![ + ("lowercase UDP4", "udp", ipv4, NetworkType::Udp4), + ("uppercase UDP4", "UDP", ipv4, NetworkType::Udp4), + ("lowercase UDP6", "udp", ipv6, NetworkType::Udp6), + ("uppercase UDP6", "UDP", ipv6, NetworkType::Udp6), + ]; + + for (name, in_network, in_ip, expected) in tests { + let actual = determine_network_type(in_network, &in_ip)?; + + assert_eq!( + actual, expected, + "NetworkTypeParsing: '{name}' -- input:{in_network} expected:{expected} actual:{actual}" + ); + } + + Ok(()) +} + +#[test] +fn test_network_type_parsing_failure() -> Result<()> { + let ipv6: IpAddr = "fe80::a3:6ff:fec4:5454".parse().unwrap(); + + let tests = vec![("invalid network", "junkNetwork", ipv6)]; + for (name, in_network, in_ip) in tests { + let result = determine_network_type(in_network, &in_ip); + assert!( + result.is_err(), + "NetworkTypeParsing should fail: '{name}' -- input:{in_network}", + ); + } + + Ok(()) +} + +#[test] +fn test_network_type_is_udp() -> Result<()> { + assert!(NetworkType::Udp4.is_udp()); + assert!(NetworkType::Udp6.is_udp()); + assert!(!NetworkType::Udp4.is_tcp()); + assert!(!NetworkType::Udp6.is_tcp()); + + Ok(()) +} + +#[test] +fn test_network_type_is_tcp() -> Result<()> { + assert!(NetworkType::Tcp4.is_tcp()); + assert!(NetworkType::Tcp6.is_tcp()); + assert!(!NetworkType::Tcp4.is_udp()); + assert!(!NetworkType::Tcp6.is_udp()); + + Ok(()) +} + +#[test] +fn test_network_type_serialization() { + let tests = vec![ + (NetworkType::Tcp4, "\"tcp4\""), + (NetworkType::Tcp6, "\"tcp6\""), + (NetworkType::Udp4, "\"udp4\""), + (NetworkType::Udp6, "\"udp6\""), + (NetworkType::Unspecified, "\"unspecified\""), + ]; + + for (network_type, expected_string) in tests { + assert_eq!( + expected_string.to_string(), + serde_json::to_string(&network_type).unwrap() + ); + } +} + +#[test] +fn test_network_type_to_string() { + let tests = vec![ + (NetworkType::Tcp4, "tcp4"), + (NetworkType::Tcp6, "tcp6"), + (NetworkType::Udp4, "udp4"), + (NetworkType::Udp6, "udp6"), + (NetworkType::Unspecified, "unspecified"), + ]; + + for (network_type, expected_string) in tests { + assert_eq!(network_type.to_string(), expected_string); + } +} diff --git a/reserved/ice/src/priority/mod.rs b/reserved/ice/src/priority/mod.rs new file mode 100644 index 0000000..8a00c81 --- /dev/null +++ b/reserved/ice/src/priority/mod.rs @@ -0,0 +1,36 @@ +#[cfg(test)] +mod priority_test; + +use stun::attributes::ATTR_PRIORITY; +use stun::checks::*; +use stun::message::*; + +/// Represents PRIORITY attribute. 
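+/// The 32-bit value is the candidate priority that the sending agent computed for the
+/// candidate in use; RFC 8445 uses it, among other things, to assign a priority to
+/// peer-reflexive candidates learned from a connectivity check.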
+#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] +pub struct PriorityAttr(pub u32); + +const PRIORITY_SIZE: usize = 4; // 32 bit + +impl Setter for PriorityAttr { + // add_to adds PRIORITY attribute to message. + fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { + let mut v = vec![0_u8; PRIORITY_SIZE]; + v.copy_from_slice(&self.0.to_be_bytes()); + m.add(ATTR_PRIORITY, &v); + Ok(()) + } +} + +impl PriorityAttr { + /// Decodes PRIORITY attribute from message. + pub fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { + let v = m.get(ATTR_PRIORITY)?; + + check_size(ATTR_PRIORITY, v.len(), PRIORITY_SIZE)?; + + let p = u32::from_be_bytes([v[0], v[1], v[2], v[3]]); + self.0 = p; + + Ok(()) + } +} diff --git a/reserved/ice/src/priority/priority_test.rs b/reserved/ice/src/priority/priority_test.rs new file mode 100644 index 0000000..231ca7c --- /dev/null +++ b/reserved/ice/src/priority/priority_test.rs @@ -0,0 +1,39 @@ +use super::*; +use crate::error::Result; + +#[test] +fn test_priority_get_from() -> Result<()> { + let mut m = Message::new(); + let mut p = PriorityAttr::default(); + let result = p.get_from(&m); + if let Err(err) = result { + assert_eq!(err, stun::Error::ErrAttributeNotFound, "unexpected error"); + } else { + panic!("expected error, but got ok"); + } + + m.build(&[Box::new(BINDING_REQUEST), Box::new(p)])?; + + let mut m1 = Message::new(); + m1.write(&m.raw)?; + + let mut p1 = PriorityAttr::default(); + p1.get_from(&m1)?; + + assert_eq!(p1, p, "not equal"); + + //"IncorrectSize" + { + let mut m3 = Message::new(); + m3.add(ATTR_PRIORITY, &[0; 100]); + let mut p2 = PriorityAttr::default(); + let result = p2.get_from(&m3); + if let Err(err) = result { + assert!(is_attr_size_invalid(&err), "should error"); + } else { + panic!("expected error, but got ok"); + } + } + + Ok(()) +} diff --git a/reserved/ice/src/rand/mod.rs b/reserved/ice/src/rand/mod.rs new file mode 100644 index 0000000..84001a8 --- /dev/null +++ b/reserved/ice/src/rand/mod.rs @@ -0,0 +1,48 @@ +#[cfg(test)] +mod rand_test; + +use rand::{thread_rng, Rng}; + +const RUNES_ALPHA: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; +const RUNES_CANDIDATE_ID_FOUNDATION: &[u8] = + b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789/+"; + +const LEN_UFRAG: usize = 16; +const LEN_PWD: usize = 32; + +//TODO: generates a random string for cryptographic usage. +pub fn generate_crypto_random_string(n: usize, runes: &[u8]) -> String { + let mut rng = thread_rng(); + + let rand_string: String = (0..n) + .map(|_| { + let idx = rng.gen_range(0..runes.len()); + runes[idx] as char + }) + .collect(); + + rand_string +} + +/// https://tools.ietf.org/html/rfc5245#section-15.1 +/// candidate-id = "candidate" ":" foundation +/// foundation = 1*32ice-char +/// ice-char = ALPHA / DIGIT / "+" / "/" +pub fn generate_cand_id() -> String { + format!( + "candidate:{}", + generate_crypto_random_string(32, RUNES_CANDIDATE_ID_FOUNDATION) + ) +} + +/// Generates ICE pwd. +/// This internally uses `generate_crypto_random_string`. +pub fn generate_pwd() -> String { + generate_crypto_random_string(LEN_PWD, RUNES_ALPHA) +} + +/// ICE user fragment. +/// This internally uses `generate_crypto_random_string`. 
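+/// The generated fragment is `LEN_UFRAG` (16) characters drawn from `RUNES_ALPHA`, which
+/// comfortably exceeds the 24-bit minimum that `Error::ErrLocalUfragInsufficientBits`
+/// guards against.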
+pub fn generate_ufrag() -> String {
+    generate_crypto_random_string(LEN_UFRAG, RUNES_ALPHA)
+}
diff --git a/reserved/ice/src/rand/rand_test.rs b/reserved/ice/src/rand/rand_test.rs
new file mode 100644
index 0000000..bf2fdca
--- /dev/null
+++ b/reserved/ice/src/rand/rand_test.rs
@@ -0,0 +1,77 @@
+use std::sync::Arc;
+
+use tokio::sync::Mutex;
+use waitgroup::WaitGroup;
+
+use super::*;
+use crate::error::Result;
+
+#[tokio::test]
+async fn test_random_generator_collision() -> Result<()> {
+    let test_cases = vec![
+        (
+            "CandidateID",
+            0, /*||-> String {
+                   generate_cand_id()
+               },*/
+        ),
+        (
+            "PWD", 1, /*||-> String {
+                  generate_pwd()
+              },*/
+        ),
+        (
+            "Ufrag", 2, /*|| ->String {
+                  generate_ufrag()
+              },*/
+        ),
+    ];
+
+    const N: usize = 10;
+    const ITERATION: usize = 10;
+
+    for (name, test_case) in test_cases {
+        for _ in 0..ITERATION {
+            let rands = Arc::new(Mutex::new(vec![]));
+
+            // Create a new wait group.
+            let wg = WaitGroup::new();
+
+            for _ in 0..N {
+                let w = wg.worker();
+                let rs = Arc::clone(&rands);
+
+                tokio::spawn(async move {
+                    let _d = w;
+
+                    let s = if test_case == 0 {
+                        generate_cand_id()
+                    } else if test_case == 1 {
+                        generate_pwd()
+                    } else {
+                        generate_ufrag()
+                    };
+
+                    let mut r = rs.lock().await;
+                    r.push(s);
+                });
+            }
+            wg.wait().await;
+
+            let rs = rands.lock().await;
+            assert_eq!(rs.len(), N, "{name} Failed to generate randoms");
+
+            for i in 0..N {
+                for j in i + 1..N {
+                    assert_ne!(
+                        rs[i], rs[j],
+                        "{}: generateRandString caused collision: {} == {}",
+                        name, rs[i], rs[j],
+                    );
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
diff --git a/reserved/ice/src/state/mod.rs b/reserved/ice/src/state/mod.rs
new file mode 100644
index 0000000..bf21877
--- /dev/null
+++ b/reserved/ice/src/state/mod.rs
@@ -0,0 +1,112 @@
+#[cfg(test)]
+mod state_test;
+
+use std::fmt;
+
+/// An enum showing the state of an ICE connection; lists the supported states.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum ConnectionState {
+    Unspecified,
+
+    /// ICE agent is gathering addresses.
+    New,
+
+    /// ICE agent has been given local and remote candidates, and is attempting to find a match.
+    Checking,
+
+    /// ICE agent has a pairing, but is still checking other pairs.
+    Connected,
+
+    /// ICE agent has finished.
+    Completed,
+
+    /// ICE agent could never successfully connect.
+    Failed,
+
+    /// ICE agent connected successfully, but has entered a failed state.
+    Disconnected,
+
+    /// ICE agent has finished and is no longer handling requests.
+    Closed,
+}
+
+impl Default for ConnectionState {
+    fn default() -> Self {
+        Self::Unspecified
+    }
+}
+
+impl fmt::Display for ConnectionState {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let s = match *self {
+            Self::Unspecified => "Unspecified",
+            Self::New => "New",
+            Self::Checking => "Checking",
+            Self::Connected => "Connected",
+            Self::Completed => "Completed",
+            Self::Failed => "Failed",
+            Self::Disconnected => "Disconnected",
+            Self::Closed => "Closed",
+        };
+        write!(f, "{s}")
+    }
+}
+
+impl From<u8> for ConnectionState {
+    fn from(v: u8) -> Self {
+        match v {
+            1 => Self::New,
+            2 => Self::Checking,
+            3 => Self::Connected,
+            4 => Self::Completed,
+            5 => Self::Failed,
+            6 => Self::Disconnected,
+            7 => Self::Closed,
+            _ => Self::Unspecified,
+        }
+    }
+}
+
+/// Describes the state of the candidate gathering process.
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum GatheringState {
+    Unspecified,
+
+    /// Indicates candidate gathering has not yet started.
+    New,
+
+    /// Indicates candidate gathering is ongoing.
+ Gathering, + + /// Indicates candidate gathering has been completed. + Complete, +} + +impl From for GatheringState { + fn from(v: u8) -> Self { + match v { + 1 => Self::New, + 2 => Self::Gathering, + 3 => Self::Complete, + _ => Self::Unspecified, + } + } +} + +impl Default for GatheringState { + fn default() -> Self { + Self::Unspecified + } +} + +impl fmt::Display for GatheringState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + Self::New => "new", + Self::Gathering => "gathering", + Self::Complete => "complete", + Self::Unspecified => "unspecified", + }; + write!(f, "{s}") + } +} diff --git a/reserved/ice/src/state/state_test.rs b/reserved/ice/src/state/state_test.rs new file mode 100644 index 0000000..9e93820 --- /dev/null +++ b/reserved/ice/src/state/state_test.rs @@ -0,0 +1,46 @@ +use super::*; +use crate::error::Result; + +#[test] +fn test_connected_state_string() -> Result<()> { + let tests = vec![ + (ConnectionState::Unspecified, "Unspecified"), + (ConnectionState::New, "New"), + (ConnectionState::Checking, "Checking"), + (ConnectionState::Connected, "Connected"), + (ConnectionState::Completed, "Completed"), + (ConnectionState::Failed, "Failed"), + (ConnectionState::Disconnected, "Disconnected"), + (ConnectionState::Closed, "Closed"), + ]; + + for (connection_state, expected_string) in tests { + assert_eq!( + connection_state.to_string(), + expected_string, + "testCase: {expected_string} vs {connection_state}", + ) + } + + Ok(()) +} + +#[test] +fn test_gathering_state_string() -> Result<()> { + let tests = vec![ + (GatheringState::Unspecified, "unspecified"), + (GatheringState::New, "new"), + (GatheringState::Gathering, "gathering"), + (GatheringState::Complete, "complete"), + ]; + + for (gathering_state, expected_string) in tests { + assert_eq!( + gathering_state.to_string(), + expected_string, + "testCase: {expected_string} vs {gathering_state}", + ) + } + + Ok(()) +} diff --git a/reserved/ice/src/stats/mod.rs b/reserved/ice/src/stats/mod.rs new file mode 100644 index 0000000..e3fc406 --- /dev/null +++ b/reserved/ice/src/stats/mod.rs @@ -0,0 +1,178 @@ +use tokio::time::Instant; + +use crate::candidate::*; +use crate::network_type::*; + +// CandidatePairStats contains ICE candidate pair statistics +#[derive(Debug, Clone)] +pub struct CandidatePairStats { + // timestamp is the timestamp associated with this object. + pub timestamp: Instant, + + // local_candidate_id is the id of the local candidate + pub local_candidate_id: String, + + // remote_candidate_id is the id of the remote candidate + pub remote_candidate_id: String, + + // state represents the state of the checklist for the local and remote + // candidates in a pair. + pub state: CandidatePairState, + + // nominated is true when this valid pair that should be used for media + // if it is the highest-priority one amongst those whose nominated flag is set + pub nominated: bool, + + // packets_sent represents the total number of packets sent on this candidate pair. + pub packets_sent: u32, + + // packets_received represents the total number of packets received on this candidate pair. + pub packets_received: u32, + + // bytes_sent represents the total number of payload bytes sent on this candidate pair + // not including headers or padding. + pub bytes_sent: u64, + + // bytes_received represents the total number of payload bytes received on this candidate pair + // not including headers or padding. 
+ pub bytes_received: u64, + + // last_packet_sent_timestamp represents the timestamp at which the last packet was + // sent on this particular candidate pair, excluding STUN packets. + pub last_packet_sent_timestamp: Instant, + + // last_packet_received_timestamp represents the timestamp at which the last packet + // was received on this particular candidate pair, excluding STUN packets. + pub last_packet_received_timestamp: Instant, + + // first_request_timestamp represents the timestamp at which the first STUN request + // was sent on this particular candidate pair. + pub first_request_timestamp: Instant, + + // last_request_timestamp represents the timestamp at which the last STUN request + // was sent on this particular candidate pair. The average interval between two + // consecutive connectivity checks sent can be calculated with + // (last_request_timestamp - first_request_timestamp) / requests_sent. + pub last_request_timestamp: Instant, + + // last_response_timestamp represents the timestamp at which the last STUN response + // was received on this particular candidate pair. + pub last_response_timestamp: Instant, + + // total_round_trip_time represents the sum of all round trip time measurements + // in seconds since the beginning of the session, based on STUN connectivity + // check responses (responses_received), including those that reply to requests + // that are sent in order to verify consent. The average round trip time can + // be computed from total_round_trip_time by dividing it by responses_received. + pub total_round_trip_time: f64, + + // current_round_trip_time represents the latest round trip time measured in seconds, + // computed from both STUN connectivity checks, including those that are sent + // for consent verification. + pub current_round_trip_time: f64, + + // available_outgoing_bitrate is calculated by the underlying congestion control + // by combining the available bitrate for all the outgoing RTP streams using + // this candidate pair. The bitrate measurement does not count the size of the + // ip or other transport layers like TCP or UDP. It is similar to the TIAS defined + // in RFC 3890, i.e., it is measured in bits per second and the bitrate is calculated + // over a 1 second window. + pub available_outgoing_bitrate: f64, + + // available_incoming_bitrate is calculated by the underlying congestion control + // by combining the available bitrate for all the incoming RTP streams using + // this candidate pair. The bitrate measurement does not count the size of the + // ip or other transport layers like TCP or UDP. It is similar to the TIAS defined + // in RFC 3890, i.e., it is measured in bits per second and the bitrate is + // calculated over a 1 second window. + pub available_incoming_bitrate: f64, + + // circuit_breaker_trigger_count represents the number of times the circuit breaker + // is triggered for this particular 5-tuple, ceasing transmission. + pub circuit_breaker_trigger_count: u32, + + // requests_received represents the total number of connectivity check requests + // received (including retransmissions). It is impossible for the receiver to + // tell whether the request was sent in order to check connectivity or check + // consent, so all connectivity checks requests are counted here. + pub requests_received: u64, + + // requests_sent represents the total number of connectivity check requests + // sent (not including retransmissions). 
+    pub requests_sent: u64,
+
+    // responses_received represents the total number of connectivity check responses received.
+    pub responses_received: u64,
+
+    // responses_sent represents the total number of connectivity check responses sent.
+    // Since we cannot distinguish connectivity check requests and consent requests,
+    // all responses are counted.
+    pub responses_sent: u64,
+
+    // retransmissions_received represents the total number of connectivity check
+    // request retransmissions received.
+    pub retransmissions_received: u64,
+
+    // retransmissions_sent represents the total number of connectivity check
+    // request retransmissions sent.
+    pub retransmissions_sent: u64,
+
+    // consent_requests_sent represents the total number of consent requests sent.
+    pub consent_requests_sent: u64,
+
+    // consent_expired_timestamp represents the timestamp at which the latest valid
+    // STUN binding response expired.
+    pub consent_expired_timestamp: Instant,
+}
+
+// CandidateStats contains ICE candidate statistics related to the ICETransport objects.
+#[derive(Debug, Clone)]
+pub struct CandidateStats {
+    // timestamp is the timestamp associated with this object.
+    pub timestamp: Instant,
+
+    // id is the candidate id
+    pub id: String,
+
+    // network_type represents the type of network interface used by the base of a
+    // local candidate (the address the ICE agent sends from). Only present for
+    // local candidates; it's not possible to know what type of network interface
+    // a remote candidate is using.
+    //
+    // Note:
+    // This stat only tells you about the network interface used by the first "hop";
+    // it's possible that a connection will be bottlenecked by another type of network.
+    // For example, when using Wi-Fi tethering, the networkType of the relevant candidate
+    // would be "wifi", even when the next hop is over a cellular connection.
+    pub network_type: NetworkType,
+
+    // ip is the ip address of the candidate, allowing for IPv4 addresses and
+    // IPv6 addresses, but fully qualified domain names (FQDNs) are not allowed.
+    pub ip: String,
+
+    // port is the port number of the candidate.
+    pub port: u16,
+
+    // candidate_type is the "Type" field of the ICECandidate.
+    pub candidate_type: CandidateType,
+
+    // priority is the "priority" field of the ICECandidate.
+    pub priority: u32,
+
+    // url is the url of the TURN or STUN server indicated in the ICE server configuration
+    // that translated this ip address. It is the url address surfaced in a
+    // PeerConnectionICEEvent.
+    pub url: String,
+
+    // relay_protocol is the protocol used by the endpoint to communicate with the
+    // TURN server. This is only present for local candidates. Valid values for
+    // the TURN url protocol are udp, tcp, or tls.
+    pub relay_protocol: String,
+
+    // deleted is true if the candidate has been deleted/freed. For host candidates, this means
+    // that any network resources (typically a socket) associated with the candidate have been
+    // released. For TURN candidates, this means the TURN allocation is no longer active.
+    //
+    // Only defined for local candidates. For remote candidates, this property is not applicable.
+ pub deleted: bool, +} diff --git a/reserved/ice/src/tcp_type/mod.rs b/reserved/ice/src/tcp_type/mod.rs new file mode 100644 index 0000000..11f7bdb --- /dev/null +++ b/reserved/ice/src/tcp_type/mod.rs @@ -0,0 +1,48 @@ +#[cfg(test)] +mod tcp_type_test; + +use std::fmt; + +// TCPType is the type of ICE TCP candidate as described in +// ttps://tools.ietf.org/html/rfc6544#section-4.5 +#[derive(PartialEq, Eq, Debug, Copy, Clone)] +pub enum TcpType { + /// The default value. For example UDP candidates do not need this field. + Unspecified, + /// Active TCP candidate, which initiates TCP connections. + Active, + /// Passive TCP candidate, only accepts TCP connections. + Passive, + /// Like `Active` and `Passive` at the same time. + SimultaneousOpen, +} + +// from creates a new TCPType from string. +impl From<&str> for TcpType { + fn from(raw: &str) -> Self { + match raw { + "active" => Self::Active, + "passive" => Self::Passive, + "so" => Self::SimultaneousOpen, + _ => Self::Unspecified, + } + } +} + +impl fmt::Display for TcpType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + Self::Active => "active", + Self::Passive => "passive", + Self::SimultaneousOpen => "so", + Self::Unspecified => "unspecified", + }; + write!(f, "{s}") + } +} + +impl Default for TcpType { + fn default() -> Self { + Self::Unspecified + } +} diff --git a/reserved/ice/src/tcp_type/tcp_type_test.rs b/reserved/ice/src/tcp_type/tcp_type_test.rs new file mode 100644 index 0000000..26a189e --- /dev/null +++ b/reserved/ice/src/tcp_type/tcp_type_test.rs @@ -0,0 +1,18 @@ +use super::*; +use crate::error::Result; + +#[test] +fn test_tcp_type() -> Result<()> { + //assert_eq!(tcpType, TCPType::Unspecified) + assert_eq!(TcpType::from("active"), TcpType::Active); + assert_eq!(TcpType::from("passive"), TcpType::Passive); + assert_eq!(TcpType::from("so"), TcpType::SimultaneousOpen); + assert_eq!(TcpType::from("something else"), TcpType::Unspecified); + + assert_eq!(TcpType::Unspecified.to_string(), "unspecified"); + assert_eq!(TcpType::Active.to_string(), "active"); + assert_eq!(TcpType::Passive.to_string(), "passive"); + assert_eq!(TcpType::SimultaneousOpen.to_string(), "so"); + + Ok(()) +} diff --git a/reserved/ice/src/udp_mux/mod.rs b/reserved/ice/src/udp_mux/mod.rs new file mode 100644 index 0000000..10af0b1 --- /dev/null +++ b/reserved/ice/src/udp_mux/mod.rs @@ -0,0 +1,338 @@ +use std::collections::HashMap; +use std::io::ErrorKind; +use std::net::SocketAddr; +use std::sync::{Arc, Weak}; + +use async_trait::async_trait; +use tokio::sync::{watch, Mutex}; +use util::sync::RwLock; +use util::{Conn, Error}; + +mod udp_mux_conn; +pub use udp_mux_conn::{UDPMuxConn, UDPMuxConnParams, UDPMuxWriter}; + +#[cfg(test)] +mod udp_mux_test; + +mod socket_addr_ext; + +use stun::attributes::ATTR_USERNAME; +use stun::message::{is_message as is_stun_message, Message as STUNMessage}; + +use crate::candidate::RECEIVE_MTU; + +/// Normalize a target socket addr for sending over a given local socket addr. This is useful when +/// a dual stack socket is used, in which case an IPv4 target needs to be mapped to an IPv6 +/// address. 
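+///
+/// For example, a target of `1.2.3.4:5000` paired with a local socket bound to an IPv6
+/// address is rewritten to `[::ffff:1.2.3.4]:5000`, while any other combination is
+/// returned unchanged.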
+fn normalize_socket_addr(target: &SocketAddr, socket_addr: &SocketAddr) -> SocketAddr { + match (target, socket_addr) { + (SocketAddr::V4(target_ipv4), SocketAddr::V6(_)) => { + let ipv6_mapped = target_ipv4.ip().to_ipv6_mapped(); + + SocketAddr::new(std::net::IpAddr::V6(ipv6_mapped), target_ipv4.port()) + } + // This will fail later if target is IPv6 and socket is IPv4, we ignore it here + (_, _) => *target, + } +} + +#[async_trait] +pub trait UDPMux { + /// Close the muxing. + async fn close(&self) -> Result<(), Error>; + + /// Get the underlying connection for a given ufrag. + async fn get_conn(self: Arc, ufrag: &str) -> Result, Error>; + + /// Remove the underlying connection for a given ufrag. + async fn remove_conn_by_ufrag(&self, ufrag: &str); +} + +pub struct UDPMuxParams { + conn: Box, +} + +impl UDPMuxParams { + pub fn new(conn: C) -> Self + where + C: Conn + Send + Sync + 'static, + { + Self { + conn: Box::new(conn), + } + } +} + +pub struct UDPMuxDefault { + /// The params this instance is configured with. + /// Contains the underlying UDP socket in use + params: UDPMuxParams, + + /// Maps from ufrag to the underlying connection. + conns: Mutex>, + + /// Maps from ip address to the underlying connection. + address_map: RwLock>, + + // Close sender + closed_watch_tx: Mutex>>, + + /// Close reciever + closed_watch_rx: watch::Receiver<()>, +} + +impl UDPMuxDefault { + pub fn new(params: UDPMuxParams) -> Arc { + let (closed_watch_tx, closed_watch_rx) = watch::channel(()); + + let mux = Arc::new(Self { + params, + conns: Mutex::default(), + address_map: RwLock::default(), + closed_watch_tx: Mutex::new(Some(closed_watch_tx)), + closed_watch_rx: closed_watch_rx.clone(), + }); + + let cloned_mux = Arc::clone(&mux); + cloned_mux.start_conn_worker(closed_watch_rx); + + mux + } + + pub async fn is_closed(&self) -> bool { + self.closed_watch_tx.lock().await.is_none() + } + + /// Create a muxed connection for a given ufrag. + fn create_muxed_conn(self: &Arc, ufrag: &str) -> Result { + let local_addr = self.params.conn.local_addr()?; + + let params = UDPMuxConnParams { + local_addr, + key: ufrag.into(), + udp_mux: Arc::downgrade(self) as Weak, + }; + + Ok(UDPMuxConn::new(params)) + } + + async fn conn_from_stun_message(&self, buffer: &[u8], addr: &SocketAddr) -> Option { + let (result, message) = { + let mut m = STUNMessage::new(); + + (m.unmarshal_binary(buffer), m) + }; + + match result { + Err(err) => { + log::warn!("Failed to handle decode ICE from {}: {}", addr, err); + None + } + Ok(_) => { + let (attr, found) = message.attributes.get(ATTR_USERNAME); + if !found { + log::warn!("No username attribute in STUN message from {}", &addr); + return None; + } + + let s = match String::from_utf8(attr.value) { + // Per the RFC this shouldn't happen + // https://datatracker.ietf.org/doc/html/rfc5389#section-15.3 + Err(err) => { + log::warn!( + "Failed to decode USERNAME from STUN message as UTF-8: {}", + err + ); + return None; + } + Ok(s) => s, + }; + + let conns = self.conns.lock().await; + let conn = s + .split(':') + .next() + .and_then(|ufrag| conns.get(ufrag)) + .map(Clone::clone); + + conn + } + } + } + + fn start_conn_worker(self: Arc, mut closed_watch_rx: watch::Receiver<()>) { + tokio::spawn(async move { + let mut buffer = [0u8; RECEIVE_MTU]; + + loop { + let loop_self = Arc::clone(&self); + let conn = &loop_self.params.conn; + + tokio::select! 
{ + res = conn.recv_from(&mut buffer) => { + match res { + Ok((len, addr)) => { + // Find connection based on previously having seen this source address + let conn = { + let address_map = loop_self + .address_map + .read(); + + address_map.get(&addr).map(Clone::clone) + }; + + let conn = match conn { + // If we couldn't find the connection based on source address, see if + // this is a STUN mesage and if so if we can find the connection based on ufrag. + None if is_stun_message(&buffer) => { + loop_self.conn_from_stun_message(&buffer, &addr).await + } + s @ Some(_) => s, + _ => None, + }; + + match conn { + None => { + log::trace!("Dropping packet from {}", &addr); + } + Some(conn) => { + if let Err(err) = conn.write_packet(&buffer[..len], addr).await { + log::error!("Failed to write packet: {}", err); + } + } + } + } + Err(Error::Io(err)) if err.0.kind() == ErrorKind::TimedOut => continue, + Err(err) => { + log::error!("Could not read udp packet: {}", err); + break; + } + } + } + _ = closed_watch_rx.changed() => { + return; + } + } + } + }); + } +} + +#[async_trait] +impl UDPMux for UDPMuxDefault { + async fn close(&self) -> Result<(), Error> { + if self.is_closed().await { + return Err(Error::ErrAlreadyClosed); + } + + let mut closed_tx = self.closed_watch_tx.lock().await; + + if let Some(tx) = closed_tx.take() { + let _ = tx.send(()); + drop(closed_tx); + + let old_conns = { + let mut conns = self.conns.lock().await; + + std::mem::take(&mut (*conns)) + }; + + // NOTE: We don't wait for these closure to complete + for (_, conn) in old_conns { + conn.close(); + } + + { + let mut address_map = self.address_map.write(); + + // NOTE: This is important, we need to drop all instances of `UDPMuxConn` to + // avoid a retain cycle due to the use of [`std::sync::Arc`] on both sides. + let _ = std::mem::take(&mut (*address_map)); + } + } + + Ok(()) + } + + async fn get_conn(self: Arc, ufrag: &str) -> Result, Error> { + if self.is_closed().await { + return Err(Error::ErrUseClosedNetworkConn); + } + + { + let mut conns = self.conns.lock().await; + if let Some(conn) = conns.get(ufrag) { + // UDPMuxConn uses `Arc` internally so it's cheap to clone, but because + // we implement `Conn` we need to further wrap it in an `Arc` here. + return Ok(Arc::new(conn.clone()) as Arc); + } + + let muxed_conn = self.create_muxed_conn(ufrag)?; + let mut close_rx = muxed_conn.close_rx(); + let cloned_self = Arc::clone(&self); + let cloned_ufrag = ufrag.to_string(); + tokio::spawn(async move { + let _ = close_rx.changed().await; + + // Arc needed + cloned_self.remove_conn_by_ufrag(&cloned_ufrag).await; + }); + + conns.insert(ufrag.into(), muxed_conn.clone()); + + Ok(Arc::new(muxed_conn) as Arc) + } + } + + async fn remove_conn_by_ufrag(&self, ufrag: &str) { + // Pion's ice implementation has both `RemoveConnByFrag` and `RemoveConn`, but since `conns` + // is keyed on `ufrag` their implementation is equivalent. 
+ + let removed_conn = { + let mut conns = self.conns.lock().await; + conns.remove(ufrag) + }; + + if let Some(conn) = removed_conn { + let mut address_map = self.address_map.write(); + + for address in conn.get_addresses() { + address_map.remove(&address); + } + } + } +} + +#[async_trait] +impl UDPMuxWriter for UDPMuxDefault { + async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr) { + if self.is_closed().await { + return; + } + + let key = conn.key(); + { + let mut addresses = self.address_map.write(); + + addresses + .entry(addr) + .and_modify(|e| { + if e.key() != key { + e.remove_address(&addr); + *e = conn.clone(); + } + }) + .or_insert_with(|| conn.clone()); + } + + log::debug!("Registered {} for {}", addr, key); + } + + async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result { + self.params + .conn + .send_to(buf, *target) + .await + .map_err(Into::into) + } +} diff --git a/reserved/ice/src/udp_mux/socket_addr_ext.rs b/reserved/ice/src/udp_mux/socket_addr_ext.rs new file mode 100644 index 0000000..7290b1b --- /dev/null +++ b/reserved/ice/src/udp_mux/socket_addr_ext.rs @@ -0,0 +1,246 @@ +use std::array::TryFromSliceError; +use std::convert::TryInto; +use std::net::SocketAddr; + +use util::Error; + +pub(super) trait SocketAddrExt { + ///Encode a representation of `self` into the buffer and return the length of this encoded + ///version. + /// + /// The buffer needs to be at least 27 bytes in length. + fn encode(&self, buffer: &mut [u8]) -> Result; + + /// Decode a `SocketAddr` from a buffer. The encoding should have previously been done with + /// [`SocketAddrExt::encode`]. + fn decode(buffer: &[u8]) -> Result; +} + +const IPV4_MARKER: u8 = 4; +const IPV4_ADDRESS_SIZE: usize = 7; +const IPV6_MARKER: u8 = 6; +const IPV6_ADDRESS_SIZE: usize = 27; + +pub(super) const MAX_ADDR_SIZE: usize = IPV6_ADDRESS_SIZE; + +impl SocketAddrExt for SocketAddr { + fn encode(&self, buffer: &mut [u8]) -> Result { + use std::net::SocketAddr::{V4, V6}; + + if buffer.len() < MAX_ADDR_SIZE { + return Err(Error::ErrBufferShort); + } + + match self { + V4(addr) => { + let marker = IPV4_MARKER; + let ip: [u8; 4] = addr.ip().octets(); + let port: u16 = addr.port(); + + buffer[0] = marker; + buffer[1..5].copy_from_slice(&ip); + buffer[5..7].copy_from_slice(&port.to_le_bytes()); + + Ok(7) + } + V6(addr) => { + let marker = IPV6_MARKER; + let ip: [u8; 16] = addr.ip().octets(); + let port: u16 = addr.port(); + let flowinfo = addr.flowinfo(); + let scope_id = addr.scope_id(); + + buffer[0] = marker; + buffer[1..17].copy_from_slice(&ip); + buffer[17..19].copy_from_slice(&port.to_le_bytes()); + buffer[19..23].copy_from_slice(&flowinfo.to_le_bytes()); + buffer[23..27].copy_from_slice(&scope_id.to_le_bytes()); + + Ok(MAX_ADDR_SIZE) + } + } + } + + fn decode(buffer: &[u8]) -> Result { + use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; + + match buffer[0] { + IPV4_MARKER => { + if buffer.len() < IPV4_ADDRESS_SIZE { + return Err(Error::ErrBufferShort); + } + + let ip_parts = &buffer[1..5]; + let port = match &buffer[5..7].try_into() { + Err(_) => return Err(Error::ErrFailedToParseIpaddr), + Ok(input) => u16::from_le_bytes(*input), + }; + + let ip = Ipv4Addr::new(ip_parts[0], ip_parts[1], ip_parts[2], ip_parts[3]); + + Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))) + } + IPV6_MARKER => { + if buffer.len() < IPV6_ADDRESS_SIZE { + return Err(Error::ErrBufferShort); + } + + // Just to help the type system infer correctly + fn helper(b: &[u8]) -> Result<&[u8; 16], 
TryFromSliceError> { + b.try_into() + } + + let ip = match helper(&buffer[1..17]) { + Err(_) => return Err(Error::ErrFailedToParseIpaddr), + Ok(input) => Ipv6Addr::from(*input), + }; + let port = match &buffer[17..19].try_into() { + Err(_) => return Err(Error::ErrFailedToParseIpaddr), + Ok(input) => u16::from_le_bytes(*input), + }; + + let flowinfo = match &buffer[19..23].try_into() { + Err(_) => return Err(Error::ErrFailedToParseIpaddr), + Ok(input) => u32::from_le_bytes(*input), + }; + + let scope_id = match &buffer[23..27].try_into() { + Err(_) => return Err(Error::ErrFailedToParseIpaddr), + Ok(input) => u32::from_le_bytes(*input), + }; + + Ok(SocketAddr::V6(SocketAddrV6::new( + ip, port, flowinfo, scope_id, + ))) + } + _ => Err(Error::ErrFailedToParseIpaddr), + } + } +} + +#[cfg(test)] +mod test { + use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; + + use super::*; + + #[test] + fn test_ipv4() { + let ip = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([56, 128, 35, 5]), 0x1234)); + + let mut buffer = [0_u8; MAX_ADDR_SIZE]; + let encoded_len = ip.encode(&mut buffer); + + assert_eq!(encoded_len, Ok(7)); + assert_eq!( + &buffer[0..7], + &[IPV4_MARKER, 56, 128, 35, 5, 0x34, 0x12][..] + ); + + let decoded = SocketAddr::decode(&buffer); + + assert_eq!(decoded, Ok(ip)); + } + + #[test] + fn test_ipv6() { + let ip = SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::from([ + 92, 114, 235, 3, 244, 64, 38, 111, 20, 100, 199, 241, 19, 174, 220, 123, + ]), + 0x1234, + 0x12345678, + 0x87654321, + )); + + let mut buffer = [0_u8; MAX_ADDR_SIZE]; + let encoded_len = ip.encode(&mut buffer); + + assert_eq!(encoded_len, Ok(27)); + assert_eq!( + &buffer[0..27], + &[ + IPV6_MARKER, // marker + // Start of ipv6 address + 92, + 114, + 235, + 3, + 244, + 64, + 38, + 111, + 20, + 100, + 199, + 241, + 19, + 174, + 220, + 123, + // LE port + 0x34, + 0x12, + // LE flowinfo + 0x78, + 0x56, + 0x34, + 0x12, + // LE scope_id + 0x21, + 0x43, + 0x65, + 0x87, + ][..] 
+ ); + + let decoded = SocketAddr::decode(&buffer); + + assert_eq!(decoded, Ok(ip)); + } + + #[test] + fn test_encode_ipv4_with_short_buffer() { + let mut buffer = vec![0u8; IPV4_ADDRESS_SIZE - 1]; + let ip = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([56, 128, 35, 5]), 0x1234)); + + let result = ip.encode(&mut buffer); + + assert_eq!(result, Err(Error::ErrBufferShort)); + } + + #[test] + fn test_encode_ipv6_with_short_buffer() { + let mut buffer = vec![0u8; MAX_ADDR_SIZE - 1]; + let ip = SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::from([ + 92, 114, 235, 3, 244, 64, 38, 111, 20, 100, 199, 241, 19, 174, 220, 123, + ]), + 0x1234, + 0x12345678, + 0x87654321, + )); + + let result = ip.encode(&mut buffer); + + assert_eq!(result, Err(Error::ErrBufferShort)); + } + + #[test] + fn test_decode_ipv4_with_short_buffer() { + let buffer = vec![IPV4_MARKER, 0]; + + let result = SocketAddr::decode(&buffer); + + assert_eq!(result, Err(Error::ErrBufferShort)); + } + + #[test] + fn test_decode_ipv6_with_short_buffer() { + let buffer = vec![IPV6_MARKER, 0]; + + let result = SocketAddr::decode(&buffer); + + assert_eq!(result, Err(Error::ErrBufferShort)); + } +} diff --git a/reserved/ice/src/udp_mux/udp_mux_conn.rs b/reserved/ice/src/udp_mux/udp_mux_conn.rs new file mode 100644 index 0000000..5066e2c --- /dev/null +++ b/reserved/ice/src/udp_mux/udp_mux_conn.rs @@ -0,0 +1,316 @@ +use std::collections::HashSet; +use std::convert::TryInto; +use std::io; +use std::net::SocketAddr; +use std::sync::{Arc, Weak}; + +use async_trait::async_trait; +use tokio::sync::watch; +use util::sync::Mutex; +use util::{Buffer, Conn, Error}; + +use super::socket_addr_ext::{SocketAddrExt, MAX_ADDR_SIZE}; +use super::{normalize_socket_addr, RECEIVE_MTU}; + +/// A trait for a [`UDPMuxConn`] to communicate with an UDP mux. +#[async_trait] +pub trait UDPMuxWriter { + /// Registers an address for the given connection. + async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr); + /// Sends the content of the buffer to the given target. + /// + /// Returns the number of bytes sent or an error, if any. + async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result; +} + +/// Parameters for a [`UDPMuxConn`]. +pub struct UDPMuxConnParams { + /// Local socket address. + pub local_addr: SocketAddr, + /// Static key identifying the connection. + pub key: String, + /// A `std::sync::Weak` reference to the UDP mux. + /// + /// NOTE: a non-owning reference should be used to prevent possible cycles. + pub udp_mux: Weak, +} + +type ConnResult = Result; + +/// A UDP mux connection. +#[derive(Clone)] +pub struct UDPMuxConn { + /// Close Receiver. A copy of this can be obtained via [`close_tx`]. + closed_watch_rx: watch::Receiver, + + inner: Arc, +} + +impl UDPMuxConn { + /// Creates a new [`UDPMuxConn`]. + pub fn new(params: UDPMuxConnParams) -> Self { + let (closed_watch_tx, closed_watch_rx) = watch::channel(false); + + Self { + closed_watch_rx, + inner: Arc::new(UDPMuxConnInner { + params, + closed_watch_tx: Mutex::new(Some(closed_watch_tx)), + addresses: Default::default(), + buffer: Buffer::new(0, 0), + }), + } + } + + /// Returns a key identifying this connection. + pub fn key(&self) -> &str { + &self.inner.params.key + } + + /// Writes data to the given address. Returns an error if the buffer is too short or there's an + /// encoding error. + pub async fn write_packet(&self, data: &[u8], addr: SocketAddr) -> ConnResult<()> { + // NOTE: Pion/ice uses Sync.Pool to optimise this. 
+ let mut buffer = make_buffer(); + let mut offset = 0; + + if (data.len() + MAX_ADDR_SIZE) > (RECEIVE_MTU + MAX_ADDR_SIZE) { + return Err(Error::ErrBufferShort); + } + + // Format of buffer: | data len(2) | data bytes(dn) | addr len(2) | addr bytes(an) | + // Where the number in parenthesis indicate the number of bytes used + // `dn` and `an` are the length in bytes of data and addr respectively. + + // SAFETY: `data.len()` is at most RECEIVE_MTU(8192) - MAX_ADDR_SIZE(27) + buffer[0..2].copy_from_slice(&(data.len() as u16).to_le_bytes()[..]); + offset += 2; + + buffer[offset..offset + data.len()].copy_from_slice(data); + offset += data.len(); + + let len = addr.encode(&mut buffer[offset + 2..])?; + buffer[offset..offset + 2].copy_from_slice(&(len as u16).to_le_bytes()[..]); + offset += 2 + len; + + self.inner.buffer.write(&buffer[..offset]).await?; + + Ok(()) + } + + /// Returns true if this connection is closed. + pub fn is_closed(&self) -> bool { + self.inner.is_closed() + } + + /// Gets a copy of the close [`tokio::sync::watch::Receiver`] that fires when this + /// connection is closed. + pub fn close_rx(&self) -> watch::Receiver { + self.closed_watch_rx.clone() + } + + /// Closes this connection. + pub fn close(&self) { + self.inner.close(); + } + + /// Gets the list of the addresses associated with this connection. + pub fn get_addresses(&self) -> Vec { + self.inner.get_addresses() + } + + /// Registers a new address for this connection. + pub async fn add_address(&self, addr: SocketAddr) { + self.inner.add_address(addr); + if let Some(mux) = self.inner.params.udp_mux.upgrade() { + mux.register_conn_for_address(self, addr).await; + } + } + + /// Deregisters an address. + pub fn remove_address(&self, addr: &SocketAddr) { + self.inner.remove_address(addr) + } + + /// Returns true if the given address is associated with this connection. + pub fn contains_address(&self, addr: &SocketAddr) -> bool { + self.inner.contains_address(addr) + } +} + +struct UDPMuxConnInner { + params: UDPMuxConnParams, + + /// Close Sender. We'll send a value on this channel when we close + closed_watch_tx: Mutex>>, + + /// Remote addresses we've seen on this connection. + addresses: Mutex>, + + buffer: Buffer, +} + +impl UDPMuxConnInner { + // Sending/Recieving + async fn recv_from(&self, buf: &mut [u8]) -> ConnResult<(usize, SocketAddr)> { + // NOTE: Pion/ice uses Sync.Pool to optimise this. + let mut buffer = make_buffer(); + let mut offset = 0; + + let len = self.buffer.read(&mut buffer, None).await?; + // We always have at least. 
+ // + // * 2 bytes for data len + // * 2 bytes for addr len + // * 7 bytes for an Ipv4 addr + if len < 11 { + return Err(Error::ErrBufferShort); + } + + let data_len: usize = buffer[..2] + .try_into() + .map(u16::from_le_bytes) + .map(From::from) + .unwrap(); + offset += 2; + + let total = 2 + data_len + 2 + 7; + if data_len > buf.len() || total > len { + return Err(Error::ErrBufferShort); + } + + buf[..data_len].copy_from_slice(&buffer[offset..offset + data_len]); + offset += data_len; + + let address_len: usize = buffer[offset..offset + 2] + .try_into() + .map(u16::from_le_bytes) + .map(From::from) + .unwrap(); + offset += 2; + + let addr = SocketAddr::decode(&buffer[offset..offset + address_len])?; + + Ok((data_len, addr)) + } + + async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> ConnResult { + if let Some(mux) = self.params.udp_mux.upgrade() { + mux.send_to(buf, target).await + } else { + Err(Error::Other(format!( + "wanted to send {} bytes to {}, but UDP mux is gone", + buf.len(), + target + ))) + } + } + + fn is_closed(&self) -> bool { + self.closed_watch_tx.lock().is_none() + } + + fn close(self: &Arc) { + let mut closed_tx = self.closed_watch_tx.lock(); + + if let Some(tx) = closed_tx.take() { + let _ = tx.send(true); + drop(closed_tx); + + let cloned_self = Arc::clone(self); + + { + let mut addresses = self.addresses.lock(); + *addresses = Default::default(); + } + + // NOTE: Alternatively we could wait on the buffer closing here so that + // our caller can wait for things to fully settle down + tokio::spawn(async move { + cloned_self.buffer.close().await; + }); + } + } + + fn local_addr(&self) -> SocketAddr { + self.params.local_addr + } + + // Address related methods + pub(super) fn get_addresses(&self) -> Vec { + let addresses = self.addresses.lock(); + + addresses.iter().copied().collect() + } + + pub(super) fn add_address(self: &Arc, addr: SocketAddr) { + { + let mut addresses = self.addresses.lock(); + addresses.insert(addr); + } + } + + pub(super) fn remove_address(&self, addr: &SocketAddr) { + { + let mut addresses = self.addresses.lock(); + addresses.remove(addr); + } + } + + pub(super) fn contains_address(&self, addr: &SocketAddr) -> bool { + let addresses = self.addresses.lock(); + + addresses.contains(addr) + } +} + +#[async_trait] +impl Conn for UDPMuxConn { + async fn connect(&self, _addr: SocketAddr) -> ConnResult<()> { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + async fn recv(&self, _buf: &mut [u8]) -> ConnResult { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + async fn recv_from(&self, buf: &mut [u8]) -> ConnResult<(usize, SocketAddr)> { + self.inner.recv_from(buf).await + } + + async fn send(&self, _buf: &[u8]) -> ConnResult { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + async fn send_to(&self, buf: &[u8], target: SocketAddr) -> ConnResult { + let normalized_target = normalize_socket_addr(&target, &self.inner.params.local_addr); + + if !self.contains_address(&normalized_target) { + self.add_address(normalized_target).await; + } + + self.inner.send_to(buf, &normalized_target).await + } + + fn local_addr(&self) -> ConnResult { + Ok(self.inner.local_addr()) + } + + fn remote_addr(&self) -> Option { + None + } + async fn close(&self) -> ConnResult<()> { + self.inner.close(); + + Ok(()) + } +} + +#[inline(always)] +/// Create a buffer of appropriate size to fit both a packet with max RECEIVE_MTU and the +/// additional metadata used for muxing. 
+fn make_buffer() -> Vec { + // The 4 extra bytes are used to encode the length of the data and address respectively. + // See [`write_packet`] for details. + vec![0u8; RECEIVE_MTU + MAX_ADDR_SIZE + 2 + 2] +} diff --git a/reserved/ice/src/udp_mux/udp_mux_test.rs b/reserved/ice/src/udp_mux/udp_mux_test.rs new file mode 100644 index 0000000..f493f4c --- /dev/null +++ b/reserved/ice/src/udp_mux/udp_mux_test.rs @@ -0,0 +1,292 @@ +use std::convert::TryInto; +use std::io; +use std::time::Duration; + +use rand::{thread_rng, Rng}; +use sha1::{Digest, Sha1}; +use stun::message::{Message, BINDING_REQUEST}; +use tokio::net::UdpSocket; +use tokio::time::{sleep, timeout}; + +use super::*; +use crate::error::Result; + +#[derive(Debug, Copy, Clone)] +enum Network { + Ipv4, + Ipv6, +} + +impl Network { + /// Bind the UDP socket for the "remote". + async fn bind(self) -> io::Result { + match self { + Network::Ipv4 => UdpSocket::bind("0.0.0.0:0").await, + Network::Ipv6 => UdpSocket::bind("[::]:0").await, + } + } + + /// Connnect ip from the "remote". + fn connect_ip(self, port: u16) -> String { + match self { + Network::Ipv4 => format!("127.0.0.1:{port}"), + Network::Ipv6 => format!("[::1]:{port}"), + } + } +} + +const TIMEOUT: Duration = Duration::from_secs(60); + +#[tokio::test] +async fn test_udp_mux() -> Result<()> { + use std::io::Write; + env_logger::Builder::from_default_env() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .init(); + + // TODO: Support IPv6 dual stack. This works Linux and macOS, but not Windows. + #[cfg(all(unix, target_pointer_width = "64"))] + let udp_socket = UdpSocket::bind((std::net::Ipv6Addr::UNSPECIFIED, 0)).await?; + + #[cfg(any(not(unix), not(target_pointer_width = "64")))] + let udp_socket = UdpSocket::bind((std::net::Ipv4Addr::UNSPECIFIED, 0)).await?; + + let addr = udp_socket.local_addr()?; + log::info!("Listening on {}", addr); + + let udp_mux = UDPMuxDefault::new(UDPMuxParams::new(udp_socket)); + let udp_mux_dyn = Arc::clone(&udp_mux) as Arc; + + let udp_mux_dyn_1 = Arc::clone(&udp_mux_dyn); + let h1 = tokio::spawn(async move { + timeout( + TIMEOUT, + test_mux_connection(Arc::clone(&udp_mux_dyn_1), "ufrag1", addr, Network::Ipv4), + ) + .await + }); + + let udp_mux_dyn_2 = Arc::clone(&udp_mux_dyn); + let h2 = tokio::spawn(async move { + timeout( + TIMEOUT, + test_mux_connection(Arc::clone(&udp_mux_dyn_2), "ufrag2", addr, Network::Ipv4), + ) + .await + }); + + let all_results; + + #[cfg(all(unix, target_pointer_width = "64"))] + { + // TODO: Support IPv6 dual stack. This works Linux and macOS, but not Windows. 
+ let udp_mux_dyn_3 = Arc::clone(&udp_mux_dyn); + let h3 = tokio::spawn(async move { + timeout( + TIMEOUT, + test_mux_connection(Arc::clone(&udp_mux_dyn_3), "ufrag3", addr, Network::Ipv6), + ) + .await + }); + + let (r1, r2, r3) = tokio::join!(h1, h2, h3); + all_results = [r1, r2, r3]; + } + + #[cfg(any(not(unix), not(target_pointer_width = "64")))] + { + let (r1, r2) = tokio::join!(h1, h2); + all_results = [r1, r2]; + } + + for timeout_result in &all_results { + // Timeout error + match timeout_result { + Err(timeout_err) => { + panic!("Mux test timedout: {timeout_err:?}"); + } + + // Join error + Ok(join_result) => match join_result { + Err(err) => { + panic!("Mux test failed with join error: {err:?}"); + } + // Actual error + Ok(mux_result) => { + if let Err(err) = mux_result { + panic!("Mux test failed with error: {err:?}"); + } + } + }, + } + } + + let timeout = all_results.iter().find_map(|r| r.as_ref().err()); + assert!( + timeout.is_none(), + "At least one of the muxed tasks timedout {all_results:?}" + ); + + let res = udp_mux.close().await; + assert!(res.is_ok()); + let res = udp_mux.get_conn("failurefrag").await; + + assert!( + res.is_err(), + "Getting connections after UDPMuxDefault is closed should fail" + ); + + Ok(()) +} + +async fn test_mux_connection( + mux: Arc, + ufrag: &str, + listener_addr: SocketAddr, + network: Network, +) -> Result<()> { + let conn = mux.get_conn(ufrag).await?; + // FIXME: Cleanup + + let connect_addr = network + .connect_ip(listener_addr.port()) + .parse::() + .unwrap(); + + let remote_connection = Arc::new(network.bind().await?); + log::info!("Bound for ufrag: {}", ufrag); + remote_connection.connect(connect_addr).await?; + log::info!("Connected to {} for ufrag: {}", connect_addr, ufrag); + log::info!( + "Testing muxing from {} over {}", + remote_connection.local_addr().unwrap(), + listener_addr + ); + + // These bytes should be dropped + remote_connection.send("Droppped bytes".as_bytes()).await?; + + sleep(Duration::from_millis(1)).await; + + let stun_msg = { + let mut m = Message { + typ: BINDING_REQUEST, + ..Message::default() + }; + + m.add(ATTR_USERNAME, format!("{ufrag}:otherufrag").as_bytes()); + + m.marshal_binary().unwrap() + }; + + let remote_connection_addr = remote_connection.local_addr()?; + + conn.send_to(&stun_msg, remote_connection_addr).await?; + + let mut buffer = vec![0u8; RECEIVE_MTU]; + let len = remote_connection.recv(&mut buffer).await?; + assert_eq!(buffer[..len], stun_msg); + + const TARGET_SIZE: usize = 1024 * 1024; + + // Read on the muxed side + let conn_2 = Arc::clone(&conn); + let mux_handle = tokio::spawn(async move { + let conn = conn_2; + + let mut buffer = vec![0u8; RECEIVE_MTU]; + let mut next_sequence = 0; + let mut read = 0; + + while read < TARGET_SIZE { + let (n, _) = conn + .recv_from(&mut buffer) + .await + .expect("recv_from should not error"); + assert_eq!(n, RECEIVE_MTU); + + verify_packet(&buffer[..n], next_sequence); + + conn.send_to(&buffer[..n], remote_connection_addr) + .await + .expect("Failed to write to muxxed connection"); + + read += n; + log::debug!("Muxxed read {}, sequence: {}", read, next_sequence); + next_sequence += 1; + } + }); + + let remote_connection_2 = Arc::clone(&remote_connection); + let remote_handle = tokio::spawn(async move { + let remote_connection = remote_connection_2; + let mut buffer = vec![0u8; RECEIVE_MTU]; + let mut next_sequence = 0; + let mut read = 0; + + while read < TARGET_SIZE { + let n = remote_connection + .recv(&mut buffer) + .await + .expect("recv_from should 
not error"); + assert_eq!(n, RECEIVE_MTU); + + verify_packet(&buffer[..n], next_sequence); + read += n; + log::debug!("Remote read {}, sequence: {}", read, next_sequence); + next_sequence += 1; + } + }); + + let mut sequence: u32 = 0; + let mut written = 0; + let mut buffer = vec![0u8; RECEIVE_MTU]; + while written < TARGET_SIZE { + thread_rng().fill(&mut buffer[24..]); + + let hash = sha1_hash(&buffer[24..]); + buffer[4..24].copy_from_slice(&hash); + buffer[0..4].copy_from_slice(&sequence.to_le_bytes()); + + let len = remote_connection.send(&buffer).await?; + + written += len; + log::debug!("Data written {}, sequence: {}", written, sequence); + sequence += 1; + + sleep(Duration::from_millis(1)).await; + } + + let (r1, r2) = tokio::join!(mux_handle, remote_handle); + assert!(r1.is_ok() && r2.is_ok()); + + let res = conn.close().await; + assert!(res.is_ok(), "Failed to close Conn: {res:?}"); + + Ok(()) +} + +fn verify_packet(buffer: &[u8], next_sequence: u32) { + let read_sequence = u32::from_le_bytes(buffer[0..4].try_into().unwrap()); + assert_eq!(read_sequence, next_sequence); + + let hash = sha1_hash(&buffer[24..]); + assert_eq!(hash, buffer[4..24]); +} + +fn sha1_hash(buffer: &[u8]) -> Vec { + let mut hasher = Sha1::new(); + hasher.update(&buffer[24..]); + + hasher.finalize().to_vec() +} diff --git a/reserved/ice/src/udp_network.rs b/reserved/ice/src/udp_network.rs new file mode 100644 index 0000000..fecb8aa --- /dev/null +++ b/reserved/ice/src/udp_network.rs @@ -0,0 +1,116 @@ +use std::sync::Arc; + +use super::udp_mux::UDPMux; +use super::Error; + +#[derive(Default, Clone)] +pub struct EphemeralUDP { + port_min: u16, + port_max: u16, +} + +impl EphemeralUDP { + pub fn new(port_min: u16, port_max: u16) -> Result { + let mut s = Self::default(); + s.set_ports(port_min, port_max)?; + + Ok(s) + } + + pub fn port_min(&self) -> u16 { + self.port_min + } + + pub fn port_max(&self) -> u16 { + self.port_max + } + + pub fn set_ports(&mut self, port_min: u16, port_max: u16) -> Result<(), Error> { + if port_max < port_min { + return Err(Error::ErrPort); + } + + self.port_min = port_min; + self.port_max = port_max; + + Ok(()) + } +} + +/// Configuration for the underlying UDP network stack. +/// There are two ways to configure this Ephemeral and Muxed. +/// +/// **Ephemeral mode** +/// +/// In Ephemeral mode sockets are created and bound to random ports during ICE +/// gathering. The ports to use can be restricted by setting [`EphemeralUDP::port_min`] and +/// [`EphemeralEphemeralUDP::port_max`] in which case only ports in this range will be used. +/// +/// **Muxed** +/// +/// In muxed mode a single UDP socket is used and all connections are muxed over this single socket. 
+/// +#[derive(Clone)] +pub enum UDPNetwork { + Ephemeral(EphemeralUDP), + Muxed(Arc), +} + +impl Default for UDPNetwork { + fn default() -> Self { + Self::Ephemeral(Default::default()) + } +} + +impl UDPNetwork { + fn is_ephemeral(&self) -> bool { + matches!(self, Self::Ephemeral(_)) + } + + fn is_muxed(&self) -> bool { + matches!(self, Self::Muxed(_)) + } +} + +#[cfg(test)] +mod test { + use super::EphemeralUDP; + + #[test] + fn test_ephemeral_udp_constructor() { + assert!( + EphemeralUDP::new(3000, 2999).is_err(), + "EphemeralUDP should not allow invalid port range" + ); + + let e = EphemeralUDP::default(); + assert_eq!(e.port_min(), 0, "EphemeralUDP should default port_min to 0"); + assert_eq!(e.port_max(), 0, "EphemeralUDP should default port_max to 0"); + } + + #[test] + fn test_ephemeral_udp_set_ports() { + let mut e = EphemeralUDP::default(); + + assert!( + e.set_ports(3000, 2999).is_err(), + "EphemeralUDP should not allow invalid port range" + ); + + assert!( + e.set_ports(6000, 6001).is_ok(), + "EphemeralUDP::set_ports should allow valid port range" + ); + + assert_eq!( + e.port_min(), + 6000, + "Ports set with `EphemeralUDP::set_ports` should be reflected" + ); + assert_eq!( + e.port_max(), + 6001, + "Ports set with `EphemeralUDP::set_ports` should be reflected" + ); + } +} diff --git a/reserved/ice/src/url/mod.rs b/reserved/ice/src/url/mod.rs new file mode 100644 index 0000000..d55fd48 --- /dev/null +++ b/reserved/ice/src/url/mod.rs @@ -0,0 +1,266 @@ +#[cfg(test)] +mod url_test; + +use std::borrow::Cow; +use std::convert::From; +use std::fmt; + +use crate::error::*; + +/// The type of server used in the ice.URL structure. +#[derive(PartialEq, Eq, Debug, Copy, Clone)] +pub enum SchemeType { + /// The URL represents a STUN server. + Stun, + + /// The URL represents a STUNS (secure) server. + Stuns, + + /// The URL represents a TURN server. + Turn, + + /// The URL represents a TURNS (secure) server. + Turns, + + /// Default public constant to use for "enum" like struct comparisons when no value was defined. + Unknown, +} + +impl Default for SchemeType { + fn default() -> Self { + Self::Unknown + } +} + +impl From<&str> for SchemeType { + /// Defines a procedure for creating a new `SchemeType` from a raw + /// string naming the scheme type. + fn from(raw: &str) -> Self { + match raw { + "stun" => Self::Stun, + "stuns" => Self::Stuns, + "turn" => Self::Turn, + "turns" => Self::Turns, + _ => Self::Unknown, + } + } +} + +impl fmt::Display for SchemeType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + SchemeType::Stun => "stun", + SchemeType::Stuns => "stuns", + SchemeType::Turn => "turn", + SchemeType::Turns => "turns", + SchemeType::Unknown => "unknown", + }; + write!(f, "{s}") + } +} + +/// The transport protocol type that is used in the `ice::url::Url` structure. +#[derive(PartialEq, Eq, Debug, Copy, Clone)] +pub enum ProtoType { + /// The URL uses a UDP transport. + Udp, + + /// The URL uses a TCP transport. + Tcp, + + Unknown, +} + +impl Default for ProtoType { + fn default() -> Self { + Self::Udp + } +} + +// defines a procedure for creating a new ProtoType from a raw +// string naming the transport protocol type. +impl From<&str> for ProtoType { + // NewSchemeType defines a procedure for creating a new SchemeType from a raw + // string naming the scheme type. 
+ fn from(raw: &str) -> Self { + match raw { + "udp" => Self::Udp, + "tcp" => Self::Tcp, + _ => Self::Unknown, + } + } +} + +impl fmt::Display for ProtoType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + Self::Udp => "udp", + Self::Tcp => "tcp", + Self::Unknown => "unknown", + }; + write!(f, "{s}") + } +} + +/// Represents a STUN (rfc7064) or TURN (rfc7065) URL. +#[derive(Debug, Clone, Default)] +pub struct Url { + pub scheme: SchemeType, + pub host: String, + pub port: u16, + pub username: String, + pub password: String, + pub proto: ProtoType, +} + +impl fmt::Display for Url { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let host = if self.host.contains("::") { + "[".to_owned() + self.host.as_str() + "]" + } else { + self.host.clone() + }; + if self.scheme == SchemeType::Turn || self.scheme == SchemeType::Turns { + write!( + f, + "{}:{}:{}?transport={}", + self.scheme, host, self.port, self.proto + ) + } else { + write!(f, "{}:{}:{}", self.scheme, host, self.port) + } + } +} + +impl Url { + /// Parses a STUN or TURN urls following the ABNF syntax described in + /// [IETF rfc-7064](https://tools.ietf.org/html/rfc7064) and + /// [IETF rfc-7065](https://tools.ietf.org/html/rfc7065) respectively. + pub fn parse_url(raw: &str) -> Result { + // work around for url crate + if raw.contains("//") { + return Err(Error::ErrInvalidUrl); + } + + let mut s = raw.to_string(); + let pos = raw.find(':'); + if let Some(p) = pos { + s.replace_range(p..=p, "://"); + } else { + return Err(Error::ErrSchemeType); + } + + let raw_parts = url::Url::parse(&s)?; + + let scheme = raw_parts.scheme().into(); + + let host = if let Some(host) = raw_parts.host_str() { + host.trim() + .trim_start_matches('[') + .trim_end_matches(']') + .to_owned() + } else { + return Err(Error::ErrHost); + }; + + let port = if let Some(port) = raw_parts.port() { + port + } else if scheme == SchemeType::Stun || scheme == SchemeType::Turn { + 3478 + } else { + 5349 + }; + + let mut q_args = raw_parts.query_pairs(); + let proto = match scheme { + SchemeType::Stun => { + if q_args.count() > 0 { + return Err(Error::ErrStunQuery); + } + ProtoType::Udp + } + SchemeType::Stuns => { + if q_args.count() > 0 { + return Err(Error::ErrStunQuery); + } + ProtoType::Tcp + } + SchemeType::Turn => { + if q_args.count() > 1 { + return Err(Error::ErrInvalidQuery); + } + if let Some((key, value)) = q_args.next() { + if key == Cow::Borrowed("transport") { + let proto: ProtoType = value.as_ref().into(); + if proto == ProtoType::Unknown { + return Err(Error::ErrProtoType); + } + proto + } else { + return Err(Error::ErrInvalidQuery); + } + } else { + ProtoType::Udp + } + } + SchemeType::Turns => { + if q_args.count() > 1 { + return Err(Error::ErrInvalidQuery); + } + if let Some((key, value)) = q_args.next() { + if key == Cow::Borrowed("transport") { + let proto: ProtoType = value.as_ref().into(); + if proto == ProtoType::Unknown { + return Err(Error::ErrProtoType); + } + proto + } else { + return Err(Error::ErrInvalidQuery); + } + } else { + ProtoType::Tcp + } + } + SchemeType::Unknown => { + return Err(Error::ErrSchemeType); + } + }; + + Ok(Self { + scheme, + host, + port, + username: "".to_owned(), + password: "".to_owned(), + proto, + }) + } + + /* + fn parse_proto(raw:&str) ->Result { + let qArgs= raw.split('='); + if qArgs.len() != 2 { + return Err(Error::ErrInvalidQuery.into()); + } + + var proto ProtoType + if rawProto := qArgs.Get("transport"); rawProto != "" { + if proto = 
NewProtoType(rawProto); proto == ProtoType(0) { + return ProtoType(Unknown), ErrProtoType + } + return proto, nil + } + + if len(qArgs) > 0 { + return ProtoType(Unknown), ErrInvalidQuery + } + + return proto, nil + }*/ + + /// Returns whether the this URL's scheme describes secure scheme or not. + #[must_use] + pub fn is_secure(&self) -> bool { + self.scheme == SchemeType::Stuns || self.scheme == SchemeType::Turns + } +} diff --git a/reserved/ice/src/url/url_test.rs b/reserved/ice/src/url/url_test.rs new file mode 100644 index 0000000..acbf727 --- /dev/null +++ b/reserved/ice/src/url/url_test.rs @@ -0,0 +1,142 @@ +use super::*; + +#[test] +fn test_parse_url_success() -> Result<()> { + let tests = vec![ + ( + "stun:google.de", + "stun:google.de:3478", + SchemeType::Stun, + false, + "google.de", + 3478, + ProtoType::Udp, + ), + ( + "stun:google.de:1234", + "stun:google.de:1234", + SchemeType::Stun, + false, + "google.de", + 1234, + ProtoType::Udp, + ), + ( + "stuns:google.de", + "stuns:google.de:5349", + SchemeType::Stuns, + true, + "google.de", + 5349, + ProtoType::Tcp, + ), + ( + "stun:[::1]:123", + "stun:[::1]:123", + SchemeType::Stun, + false, + "::1", + 123, + ProtoType::Udp, + ), + ( + "turn:google.de", + "turn:google.de:3478?transport=udp", + SchemeType::Turn, + false, + "google.de", + 3478, + ProtoType::Udp, + ), + ( + "turns:google.de", + "turns:google.de:5349?transport=tcp", + SchemeType::Turns, + true, + "google.de", + 5349, + ProtoType::Tcp, + ), + ( + "turn:google.de?transport=udp", + "turn:google.de:3478?transport=udp", + SchemeType::Turn, + false, + "google.de", + 3478, + ProtoType::Udp, + ), + ( + "turns:google.de?transport=tcp", + "turns:google.de:5349?transport=tcp", + SchemeType::Turns, + true, + "google.de", + 5349, + ProtoType::Tcp, + ), + ]; + + for ( + raw_url, + expected_url_string, + expected_scheme, + expected_secure, + expected_host, + expected_port, + expected_proto, + ) in tests + { + let url = Url::parse_url(raw_url)?; + + assert_eq!(url.scheme, expected_scheme, "testCase: {raw_url:?}"); + assert_eq!( + expected_url_string, + url.to_string(), + "testCase: {raw_url:?}" + ); + assert_eq!(url.is_secure(), expected_secure, "testCase: {raw_url:?}"); + assert_eq!(url.host, expected_host, "testCase: {raw_url:?}"); + assert_eq!(url.port, expected_port, "testCase: {raw_url:?}"); + assert_eq!(url.proto, expected_proto, "testCase: {raw_url:?}"); + } + + Ok(()) +} + +#[test] +fn test_parse_url_failure() -> Result<()> { + let tests = vec![ + ("", Error::ErrSchemeType), + (":::", Error::ErrUrlParse), + ("stun:[::1]:123:", Error::ErrPort), + ("stun:[::1]:123a", Error::ErrPort), + ("google.de", Error::ErrSchemeType), + ("stun:", Error::ErrHost), + ("stun:google.de:abc", Error::ErrPort), + ("stun:google.de?transport=udp", Error::ErrStunQuery), + ("stuns:google.de?transport=udp", Error::ErrStunQuery), + ("turn:google.de?trans=udp", Error::ErrInvalidQuery), + ("turns:google.de?trans=udp", Error::ErrInvalidQuery), + ( + "turns:google.de?transport=udp&another=1", + Error::ErrInvalidQuery, + ), + ("turn:google.de?transport=ip", Error::ErrProtoType), + ]; + + for (raw_url, expected_err) in tests { + let result = Url::parse_url(raw_url); + if let Err(err) = result { + assert_eq!( + err.to_string(), + expected_err.to_string(), + "testCase: '{raw_url}', expected err '{expected_err}', but got err '{err}'" + ); + } else { + panic!("expected error, but got ok"); + } + } + + Ok(()) +} diff --git a/reserved/ice/src/use_candidate/mod.rs b/reserved/ice/src/use_candidate/mod.rs new file mode 
100644 index 0000000..8bb0d47 --- /dev/null +++ b/reserved/ice/src/use_candidate/mod.rs @@ -0,0 +1,31 @@ +#[cfg(test)] +mod use_candidate_test; + +use stun::attributes::ATTR_USE_CANDIDATE; +use stun::message::*; + +/// Represents USE-CANDIDATE attribute. +#[derive(Default)] +pub struct UseCandidateAttr; + +impl Setter for UseCandidateAttr { + /// Adds USE-CANDIDATE attribute to message. + fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { + m.add(ATTR_USE_CANDIDATE, &[]); + Ok(()) + } +} + +impl UseCandidateAttr { + #[must_use] + pub const fn new() -> Self { + Self + } + + /// Returns true if USE-CANDIDATE attribute is set. + #[must_use] + pub fn is_set(m: &Message) -> bool { + let result = m.get(ATTR_USE_CANDIDATE); + result.is_ok() + } +} diff --git a/reserved/ice/src/use_candidate/use_candidate_test.rs b/reserved/ice/src/use_candidate/use_candidate_test.rs new file mode 100644 index 0000000..671a754 --- /dev/null +++ b/reserved/ice/src/use_candidate/use_candidate_test.rs @@ -0,0 +1,19 @@ +use stun::message::BINDING_REQUEST; + +use super::*; +use crate::error::Result; + +#[test] +fn test_use_candidate_attr_add_to() -> Result<()> { + let mut m = Message::new(); + assert!(!UseCandidateAttr::is_set(&m), "should not be set"); + + m.build(&[Box::new(BINDING_REQUEST), Box::new(UseCandidateAttr::new())])?; + + let mut m1 = Message::new(); + m1.write(&m.raw)?; + + assert!(UseCandidateAttr::is_set(&m1), "should be set"); + + Ok(()) +} diff --git a/reserved/ice/src/util/mod.rs b/reserved/ice/src/util/mod.rs new file mode 100644 index 0000000..a44cb09 --- /dev/null +++ b/reserved/ice/src/util/mod.rs @@ -0,0 +1,175 @@ +#[cfg(test)] +mod util_test; + +use std::collections::HashSet; +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; + +use stun::agent::*; +use stun::attributes::*; +use stun::integrity::*; +use stun::message::*; +use stun::textattrs::*; +use stun::xoraddr::*; +use tokio::time::Duration; +use util::vnet::net::*; +use util::Conn; + +use crate::agent::agent_config::{InterfaceFilterFn, IpFilterFn}; +use crate::error::*; +use crate::network_type::*; + +pub fn create_addr(_network: NetworkType, ip: IpAddr, port: u16) -> SocketAddr { + /*if network.is_tcp(){ + return &net.TCPAddr{IP: ip, Port: port} + default: + return &net.UDPAddr{IP: ip, Port: port} + }*/ + SocketAddr::new(ip, port) +} + +pub fn assert_inbound_username(m: &Message, expected_username: &str) -> Result<()> { + let mut username = Username::new(ATTR_USERNAME, String::new()); + username.get_from(m)?; + + if username.to_string() != expected_username { + return Err(Error::Other(format!( + "{:?} expected({}) actual({})", + Error::ErrMismatchUsername, + expected_username, + username, + ))); + } + + Ok(()) +} + +pub fn assert_inbound_message_integrity(m: &mut Message, key: &[u8]) -> Result<()> { + let message_integrity_attr = MessageIntegrity(key.to_vec()); + Ok(message_integrity_attr.check(m)?) +} + +/// Initiates a stun requests to `server_addr` using conn, reads the response and returns the +/// `XORMappedAddress` returned by the stun server. +/// Adapted from stun v0.2. 
+pub async fn get_xormapped_addr( + conn: &Arc, + server_addr: SocketAddr, + deadline: Duration, +) -> Result { + let resp = stun_request(conn, server_addr, deadline).await?; + let mut addr = XorMappedAddress::default(); + addr.get_from(&resp)?; + Ok(addr) +} + +const MAX_MESSAGE_SIZE: usize = 1280; + +pub async fn stun_request( + conn: &Arc, + server_addr: SocketAddr, + deadline: Duration, +) -> Result { + let mut request = Message::new(); + request.build(&[Box::new(BINDING_REQUEST), Box::new(TransactionId::new())])?; + + conn.send_to(&request.raw, server_addr).await?; + let mut bs = vec![0_u8; MAX_MESSAGE_SIZE]; + let (n, _) = if deadline > Duration::from_secs(0) { + match tokio::time::timeout(deadline, conn.recv_from(&mut bs)).await { + Ok(result) => match result { + Ok((n, addr)) => (n, addr), + Err(err) => return Err(Error::Other(err.to_string())), + }, + Err(err) => return Err(Error::Other(err.to_string())), + } + } else { + conn.recv_from(&mut bs).await? + }; + + let mut res = Message::new(); + res.raw = bs[..n].to_vec(); + res.decode()?; + + Ok(res) +} + +pub async fn local_interfaces( + vnet: &Arc, + interface_filter: &Option, + ip_filter: &Option, + network_types: &[NetworkType], +) -> HashSet { + let mut ips = HashSet::new(); + let interfaces = vnet.get_interfaces().await; + + let (mut ipv4requested, mut ipv6requested) = (false, false); + for typ in network_types { + if typ.is_ipv4() { + ipv4requested = true; + } + if typ.is_ipv6() { + ipv6requested = true; + } + } + + for iface in interfaces { + if let Some(filter) = interface_filter { + if !filter(iface.name()) { + continue; + } + } + + for ipnet in iface.addrs() { + let ipaddr = ipnet.addr(); + + if !ipaddr.is_loopback() + && ((ipv4requested && ipaddr.is_ipv4()) || (ipv6requested && ipaddr.is_ipv6())) + && ip_filter + .as_ref() + .map(|filter| filter(ipaddr)) + .unwrap_or(true) + { + ips.insert(ipaddr); + } + } + } + + ips +} + +pub async fn listen_udp_in_port_range( + vnet: &Arc, + port_max: u16, + port_min: u16, + laddr: SocketAddr, +) -> Result> { + if laddr.port() != 0 || (port_min == 0 && port_max == 0) { + return Ok(vnet.bind(laddr).await?); + } + let i = if port_min == 0 { 1 } else { port_min }; + let j = if port_max == 0 { 0xFFFF } else { port_max }; + if i > j { + return Err(Error::ErrPort); + } + + let port_start = rand::random::() % (j - i + 1) + i; + let mut port_current = port_start; + loop { + let laddr = SocketAddr::new(laddr.ip(), port_current); + match vnet.bind(laddr).await { + Ok(c) => return Ok(c), + Err(err) => log::debug!("failed to listen {}: {}", laddr, err), + }; + + port_current += 1; + if port_current > j { + port_current = i; + } + if port_current == port_start { + break; + } + } + + Err(Error::ErrPort) +} diff --git a/reserved/ice/src/util/util_test.rs b/reserved/ice/src/util/util_test.rs new file mode 100644 index 0000000..ab0faf9 --- /dev/null +++ b/reserved/ice/src/util/util_test.rs @@ -0,0 +1,10 @@ +use super::*; + +#[tokio::test] +async fn test_local_interfaces() -> Result<()> { + let vnet = Arc::new(Net::new(None)); + let interfaces = vnet.get_interfaces().await; + let ips = local_interfaces(&vnet, &None, &None, &[NetworkType::Udp4, NetworkType::Udp6]).await; + log::info!("interfaces: {:?}, ips: {:?}", interfaces, ips); + Ok(()) +} diff --git a/reserved/rtc-turn/Cargo.toml b/reserved/rtc-turn/Cargo.toml deleted file mode 100644 index 6604dee..0000000 --- a/reserved/rtc-turn/Cargo.toml +++ /dev/null @@ -1,14 +0,0 @@ -[package] -name = "rtc-turn" -version = "0.0.0" -authors = ["Rain Liu "] 
-edition = "2021" -description = "RTC TURN in Rust" -license = "MIT/Apache-2.0" -documentation = "https://docs.rs/rtc-turn" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/rtc" - -[dependencies] - -[dev-dependencies] diff --git a/rtc-turn/.gitignore b/rtc-turn/.gitignore new file mode 100644 index 0000000..81561ed --- /dev/null +++ b/rtc-turn/.gitignore @@ -0,0 +1,11 @@ +# Generated by Cargo +# will have compiled files and executables +/target/ +/.idea/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk diff --git a/rtc-turn/CHANGELOG.md b/rtc-turn/CHANGELOG.md new file mode 100644 index 0000000..df88930 --- /dev/null +++ b/rtc-turn/CHANGELOG.md @@ -0,0 +1,27 @@ +# webrtc-turn changelog + +## Unreleased + +* [#330 Fix the problem that the UDP port of the server relay is not released](https://github.com/webrtc-rs/webrtc/pull/330) by [@clia](https://github.com/clia). +* Added `alloc_close_notify` config parameter to `ServerConfig` and `Allocation`, to receive notify on allocation close event, with metrics data. + +## v0.6.1 + +* Added `delete_allocations_by_username` method on `Server`. This method provides possibility to manually delete allocation [#263](https://github.com/webrtc-rs/webrtc/pull/263) by [@logist322](https://github.com/logist322). +* Added `get_allocations_info` method on `Server`. This method provides possibility to get information about allocations [#288](https://github.com/webrtc-rs/webrtc/pull/288) by [@logist322](https://github.com/logist322). +* Increased minimum support rust version to `1.60.0`. +* Increased required `webrtc-util` version to `0.7.0`. + + +## v0.6.0 + +* [#15 update deps + loosen some requirements](https://github.com/webrtc-rs/turn/pull/15) by [@melekes](https://github.com/melekes). +* [#11 Fixed spelling of convenience](https://github.com/webrtc-rs/turn/pull/11) by [@Charles-Schleich ](https://github.com/Charles-Schleich). +* Increase min version of `log` dependency to `0.4.16`. [#250 Fix log at ^0.4.16 to make tests compile](https://github.com/webrtc-rs/webrtc/pull/250) by [@k0nserv](https://github.com/k0nserv). +* [#246 Fix warnings on windows](https://github.com/webrtc-rs/webrtc/pull/246) by [@https://github.com/xnorpx](https://github.com/xnorpx). + + +## Prior to 0.6.0 + +Before 0.6.0 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/turn/releases). 
+ diff --git a/rtc-turn/Cargo.toml b/rtc-turn/Cargo.toml new file mode 100644 index 0000000..ea10cfb --- /dev/null +++ b/rtc-turn/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "rtc-turn" +version = "0.0.0" +authors = ["Rain Liu "] +edition = "2021" +description = "RTC TURN in Rust" +license = "MIT/Apache-2.0" +documentation = "https://docs.rs/rtc-turn" +homepage = "https://webrtc.rs" +repository = "https://github.com/webrtc-rs/rtc" + +[dependencies] +shared = { path = "../rtc-shared", package = "rtc-shared", default-features = false, features = [] } +stun = { path = "../rtc-stun", package = "rtc-stun" } + +bytes = "1.4.0" +log = "0.4.16" +base64 = "0.21.2" +rand = "0.8.5" +ring = "0.16.20" +md-5 = "0.10.1" +thiserror = "1.0" + +[dev-dependencies] +env_logger = "0.9.0" +chrono = "0.4.23" +hex = "0.4.3" +clap = "3.2.6" +criterion = "0.4.0" + +[features] +metrics = [] + +[[bench]] +name = "bench" +harness = false + +[[example]] +name = "turn_client_udp" +path = "examples/turn_client_udp.rs" +bench = false diff --git a/rtc-turn/LICENSE-APACHE b/rtc-turn/LICENSE-APACHE new file mode 100644 index 0000000..16fe87b --- /dev/null +++ b/rtc-turn/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/rtc-turn/LICENSE-MIT b/rtc-turn/LICENSE-MIT new file mode 100644 index 0000000..e11d93b --- /dev/null +++ b/rtc-turn/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 WebRTC.rs + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/rtc-turn/README.md b/rtc-turn/README.md new file mode 100644 index 0000000..15e5853 --- /dev/null +++ b/rtc-turn/README.md @@ -0,0 +1,30 @@ +
+ WebRTC.rs
+
+ [badges: License: MIT/Apache 2.0, Discord]
+
+ A pure Rust implementation of TURN, a rewrite of Pion TURN in Rust.
+
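For a quick feel of the `proto` layer this change adds, here is a minimal, illustrative sketch that round-trips a TURN ChannelData frame. It reuses only the calls exercised by `benches/bench.rs` below (`encode`, `decode`, `is_channel_data`); the `turn::proto::...` import paths are copied from those benches and may need adjusting to the crate's published name, and the `main` wrapper is purely for illustration.

```rust
// Sketch only: round-trip a ChannelData frame with the proto types used in
// benches/bench.rs. Paths follow the bench imports (`turn::proto::...`).
use turn::proto::chandata::ChannelData;
use turn::proto::channum::{ChannelNumber, MIN_CHANNEL_NUMBER};

fn main() {
    // Wrap an application payload in a ChannelData frame bound to a channel number.
    let mut frame = ChannelData {
        data: vec![1, 2, 3, 4],
        number: ChannelNumber(MIN_CHANNEL_NUMBER),
        raw: vec![],
    };
    frame.encode(); // fills `frame.raw` with the wire representation

    // Receiver side: check the frame type first, then decode it back.
    assert!(ChannelData::is_channel_data(&frame.raw));
    let mut parsed = ChannelData {
        raw: frame.raw.clone(),
        ..Default::default()
    };
    parsed.decode().expect("well-formed ChannelData frame");
    assert_eq!(parsed.data, vec![1, 2, 3, 4]);
}
```

In the client code added further down, such frames are produced by `send_channel_data` in `src/client/relay_conn.rs` once a channel binding has succeeded.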
diff --git a/rtc-turn/benches/bench.rs b/rtc-turn/benches/bench.rs new file mode 100644 index 0000000..81513db --- /dev/null +++ b/rtc-turn/benches/bench.rs @@ -0,0 +1,137 @@ +use std::time::Duration; + +use criterion::{criterion_group, criterion_main, Criterion}; +use stun::attributes::ATTR_DATA; +use stun::message::{Getter, Message, Setter}; +use turn::proto::chandata::ChannelData; +use turn::proto::channum::{ChannelNumber, MIN_CHANNEL_NUMBER}; +use turn::proto::data::Data; +use turn::proto::lifetime::Lifetime; + +fn benchmark_chan_data(c: &mut Criterion) { + { + let buf = [64, 0, 0, 0, 0, 4, 0, 0, 1, 2, 3]; + c.bench_function("BenchmarkIsChannelData", |b| { + b.iter(|| { + assert!(ChannelData::is_channel_data(&buf)); + }) + }); + } + + { + let mut d = ChannelData { + data: vec![1, 2, 3, 4], + number: ChannelNumber(MIN_CHANNEL_NUMBER + 1), + raw: vec![], + }; + c.bench_function("BenchmarkChannelData_Encode", |b| { + b.iter(|| { + d.encode(); + d.reset(); + }) + }); + } + + { + let mut d = ChannelData { + data: vec![1, 2, 3, 4], + number: ChannelNumber(MIN_CHANNEL_NUMBER + 1), + raw: vec![], + }; + d.encode(); + let mut buf = vec![0u8; d.raw.len()]; + buf.copy_from_slice(&d.raw); + c.bench_function("BenchmarkChannelData_Decode", |b| { + b.iter(|| { + d.reset(); + d.raw = buf.clone(); + d.decode().unwrap(); + }) + }); + } +} + +fn benchmark_chan(c: &mut Criterion) { + { + let mut m = Message::new(); + c.bench_function("BenchmarkChannelNumber/AddTo", |b| { + b.iter(|| { + let n = ChannelNumber(12); + n.add_to(&mut m).unwrap(); + m.reset(); + }) + }); + } + + { + let mut m = Message::new(); + let expected = ChannelNumber(12); + expected.add_to(&mut m).unwrap(); + let mut n = ChannelNumber::default(); + c.bench_function("BenchmarkChannelNumber/GetFrom", |b| { + b.iter(|| { + n.get_from(&m).unwrap(); + assert_eq!(n, expected); + }) + }); + } +} + +fn benchmark_data(c: &mut Criterion) { + { + let mut m = Message::new(); + let d = Data(vec![0u8; 10]); + c.bench_function("BenchmarkData/AddTo", |b| { + b.iter(|| { + d.add_to(&mut m).unwrap(); + m.reset(); + }) + }); + } + + { + let mut m = Message::new(); + let d = Data(vec![0u8; 10]); + c.bench_function("BenchmarkData/AddToRaw", |b| { + b.iter(|| { + m.add(ATTR_DATA, &d.0); + m.reset(); + }) + }); + } +} + +fn benchmark_lifetime(c: &mut Criterion) { + { + let mut m = Message::new(); + let l = Lifetime(Duration::from_secs(1)); + c.bench_function("BenchmarkLifetime/AddTo", |b| { + b.iter(|| { + l.add_to(&mut m).unwrap(); + m.reset(); + }) + }); + } + + { + let mut m = Message::new(); + let expected = Lifetime(Duration::from_secs(60)); + expected.add_to(&mut m).unwrap(); + let mut l = Lifetime::default(); + c.bench_function("BenchmarkLifetime/GetFrom", |b| { + b.iter(|| { + l.get_from(&m).unwrap(); + assert_eq!(l, expected); + }) + }); + } +} + +criterion_group!( + benches, + benchmark_chan_data, + benchmark_chan, + benchmark_data, + benchmark_lifetime +); +criterion_main!(benches); diff --git a/rtc-turn/codecov.yml b/rtc-turn/codecov.yml new file mode 100644 index 0000000..bf7afa1 --- /dev/null +++ b/rtc-turn/codecov.yml @@ -0,0 +1,23 @@ +codecov: + require_ci_to_pass: yes + max_report_age: off + token: 640e45ed-ce83-43e1-9eee-473aa65dc136 + +coverage: + precision: 2 + round: down + range: 50..90 + status: + project: + default: + enabled: no + threshold: 0.2 + if_not_found: success + patch: + default: + enabled: no + if_not_found: success + changes: + default: + enabled: no + if_not_found: success diff --git a/rtc-turn/doc/webrtc.rs.png 
b/rtc-turn/doc/webrtc.rs.png new file mode 100644 index 0000000..7bf0dda Binary files /dev/null and b/rtc-turn/doc/webrtc.rs.png differ diff --git a/rtc-turn/examples/turn_client_udp.rs b/rtc-turn/examples/turn_client_udp.rs new file mode 100644 index 0000000..0ef701c --- /dev/null +++ b/rtc-turn/examples/turn_client_udp.rs @@ -0,0 +1,200 @@ +/* +use std::sync::Arc; + +use clap::{App, AppSettings, Arg}; +use tokio::net::UdpSocket; +use tokio::time::Duration; +use turn::client::*; +use turn::Error; +use util::Conn; + +// RUST_LOG=trace cargo run --color=always --package turn --example turn_client_udp -- --host 0.0.0.0 --user user=pass --ping + +#[tokio::main] +async fn main() -> Result<(), Error> { + env_logger::init(); + + let mut app = App::new("TURN Client UDP") + .version("0.1.0") + .author("Rain Liu ") + .about("An example of TURN Client UDP") + .setting(AppSettings::DeriveDisplayOrder) + .setting(AppSettings::SubcommandsNegateReqs) + .arg( + Arg::with_name("FULLHELP") + .help("Prints more detailed help information") + .long("fullhelp"), + ) + .arg( + Arg::with_name("host") + .required_unless("FULLHELP") + .takes_value(true) + .long("host") + .help("TURN Server name."), + ) + .arg( + Arg::with_name("user") + .required_unless("FULLHELP") + .takes_value(true) + .long("user") + .help("A pair of username and password (e.g. \"user=pass\")"), + ) + .arg( + Arg::with_name("realm") + .default_value("webrtc.rs") + .takes_value(true) + .long("realm") + .help("Realm (defaults to \"webrtc.rs\")"), + ) + .arg( + Arg::with_name("port") + .takes_value(true) + .default_value("3478") + .long("port") + .help("Listening port."), + ) + .arg( + Arg::with_name("ping") + .long("ping") + .takes_value(false) + .help("Run ping test"), + ); + + let matches = app.clone().get_matches(); + + if matches.is_present("FULLHELP") { + app.print_long_help().unwrap(); + std::process::exit(0); + } + + let host = matches.value_of("host").unwrap(); + let port = matches.value_of("port").unwrap(); + let user = matches.value_of("user").unwrap(); + let cred: Vec<&str> = user.splitn(2, '=').collect(); + let ping = matches.is_present("ping"); + let realm = matches.value_of("realm").unwrap(); + + // TURN client won't create a local listening socket by itself. + let conn = UdpSocket::bind("0.0.0.0:0").await?; + + let turn_server_addr = format!("{host}:{port}"); + + let cfg = ClientConfig { + stun_serv_addr: turn_server_addr.clone(), + turn_serv_addr: turn_server_addr, + username: cred[0].to_string(), + password: cred[1].to_string(), + realm: realm.to_string(), + software: String::new(), + rto_in_ms: 0, + conn: Arc::new(conn), + vnet: None, + }; + + let client = Client::new(cfg).await?; + + // Start listening on the conn provided. + client.listen().await?; + + // Allocate a relay socket on the TURN server. On success, it + // will return a net.PacketConn which represents the remote + // socket. + let relay_conn = client.allocate().await?; + + // The relayConn's local address is actually the transport + // address assigned on the TURN server. + println!("relayed-address={}", relay_conn.local_addr()?); + + // If you provided `-ping`, perform a ping test agaist the + // relayConn we have just allocated. 
+ if ping { + do_ping_test(&client, relay_conn).await?; + } + + client.close().await?; + + Ok(()) +} + +async fn do_ping_test( + client: &Client, + relay_conn: impl Conn + std::marker::Send + std::marker::Sync + 'static, +) -> Result<(), Error> { + // Send BindingRequest to learn our external IP + let mapped_addr = client.send_binding_request().await?; + + // Set up pinger socket (pingerConn) + //println!("bind..."); + let pinger_conn_tx = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); + + // Punch a UDP hole for the relay_conn by sending a data to the mapped_addr. + // This will trigger a TURN client to generate a permission request to the + // TURN server. After this, packets from the IP address will be accepted by + // the TURN server. + //println!("relay_conn send hello to mapped_addr {}", mapped_addr); + relay_conn.send_to("Hello".as_bytes(), mapped_addr).await?; + let relay_addr = relay_conn.local_addr()?; + + let pinger_conn_rx = Arc::clone(&pinger_conn_tx); + + // Start read-loop on pingerConn + tokio::spawn(async move { + let mut buf = vec![0u8; 1500]; + loop { + let (n, from) = match pinger_conn_rx.recv_from(&mut buf).await { + Ok((n, from)) => (n, from), + Err(_) => break, + }; + + let msg = match String::from_utf8(buf[..n].to_vec()) { + Ok(msg) => msg, + Err(_) => break, + }; + + println!("pingerConn read-loop: {msg} from {from}"); + /*if sentAt, pingerErr := time.Parse(time.RFC3339Nano, msg); pingerErr == nil { + rtt := time.Since(sentAt) + log.Printf("%d bytes from from %s time=%d ms\n", n, from.String(), int(rtt.Seconds()*1000)) + }*/ + } + }); + + // Start read-loop on relay_conn + tokio::spawn(async move { + let mut buf = vec![0u8; 1500]; + loop { + let (n, from) = match relay_conn.recv_from(&mut buf).await { + Err(_) => break, + Ok((n, from)) => (n, from), + }; + + println!("relay_conn read-loop: {:?} from {}", &buf[..n], from); + + // Echo back + if relay_conn.send_to(&buf[..n], from).await.is_err() { + break; + } + } + }); + + tokio::time::sleep(Duration::from_millis(500)).await; + + /*println!( + "pinger_conn_tx send 10 packets to relay addr {}...", + relay_addr + );*/ + // Send 10 packets from relay_conn to the echo server + for _ in 0..2 { + let msg = "12345678910".to_owned(); //format!("{:?}", tokio::time::Instant::now()); + println!("sending msg={} with size={}", msg, msg.as_bytes().len()); + pinger_conn_tx.send_to(msg.as_bytes(), relay_addr).await?; + + // For simplicity, this example does not wait for the pong (reply). + // Instead, sleep 1 second. + tokio::time::sleep(Duration::from_secs(1)).await; + } + + Ok(()) +} +*/ +fn main() {} diff --git a/rtc-turn/src/client/binding.rs b/rtc-turn/src/client/binding.rs new file mode 100644 index 0000000..8d86ebb --- /dev/null +++ b/rtc-turn/src/client/binding.rs @@ -0,0 +1,135 @@ +#[cfg(test)] +mod binding_test; + +use std::collections::HashMap; +use std::net::SocketAddr; +use std::time::Instant; + +// Chanel number: +// 0x4000 through 0x7FFF: These values are the allowed channel +// numbers (16,383 possible values). 
+const MIN_CHANNEL_NUMBER: u16 = 0x4000; +const MAX_CHANNEL_NUMBER: u16 = 0x7fff; + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum BindingState { + Idle, + Request, + Ready, + Refresh, + Failed, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) struct Binding { + pub(crate) number: u16, + pub(crate) st: BindingState, + pub(crate) addr: SocketAddr, + pub(crate) refreshed_at: Instant, +} + +impl Binding { + pub(crate) fn set_state(&mut self, state: BindingState) { + //atomic.StoreInt32((*int32)(&b.st), int32(state)) + self.st = state; + } + + pub(crate) fn state(&self) -> BindingState { + //return BindingState(atomic.LoadInt32((*int32)(&b.st))) + self.st + } + + pub(crate) fn set_refreshed_at(&mut self, at: Instant) { + self.refreshed_at = at; + } + + pub(crate) fn refreshed_at(&self) -> Instant { + self.refreshed_at + } +} +// Thread-safe Binding map +#[derive(Default)] +pub(crate) struct BindingManager { + chan_map: HashMap, + addr_map: HashMap, + next: u16, +} + +impl BindingManager { + pub(crate) fn new() -> Self { + BindingManager { + chan_map: HashMap::new(), + addr_map: HashMap::new(), + next: MIN_CHANNEL_NUMBER, + } + } + + pub(crate) fn assign_channel_number(&mut self) -> u16 { + let n = self.next; + if self.next == MAX_CHANNEL_NUMBER { + self.next = MIN_CHANNEL_NUMBER; + } else { + self.next += 1; + } + n + } + + pub(crate) fn create(&mut self, addr: SocketAddr) -> Option<&Binding> { + let b = Binding { + number: self.assign_channel_number(), + st: BindingState::Idle, + addr, + refreshed_at: Instant::now(), + }; + + self.chan_map.insert(b.number, b.addr.to_string()); + self.addr_map.insert(b.addr.to_string(), b); + self.addr_map.get(&addr.to_string()) + } + + pub(crate) fn find_by_addr(&self, addr: &SocketAddr) -> Option<&Binding> { + self.addr_map.get(&addr.to_string()) + } + + pub(crate) fn get_by_addr(&mut self, addr: &SocketAddr) -> Option<&mut Binding> { + self.addr_map.get_mut(&addr.to_string()) + } + + pub(crate) fn find_by_number(&self, number: u16) -> Option<&Binding> { + if let Some(s) = self.chan_map.get(&number) { + self.addr_map.get(s) + } else { + None + } + } + + pub(crate) fn get_by_number(&mut self, number: u16) -> Option<&mut Binding> { + if let Some(s) = self.chan_map.get(&number) { + self.addr_map.get_mut(s) + } else { + None + } + } + + pub(crate) fn delete_by_addr(&mut self, addr: &SocketAddr) -> bool { + if let Some(b) = self.addr_map.remove(&addr.to_string()) { + self.chan_map.remove(&b.number); + true + } else { + false + } + } + + pub(crate) fn delete_by_number(&mut self, number: u16) -> bool { + if let Some(s) = self.chan_map.remove(&number) { + self.addr_map.remove(&s); + true + } else { + false + } + } + + pub(crate) fn size(&self) -> usize { + self.addr_map.len() + } +} diff --git a/rtc-turn/src/client/binding/binding_test.rs b/rtc-turn/src/client/binding/binding_test.rs new file mode 100644 index 0000000..d8bae68 --- /dev/null +++ b/rtc-turn/src/client/binding/binding_test.rs @@ -0,0 +1,83 @@ +use std::net::{Ipv4Addr, SocketAddrV4}; + +use super::*; +use crate::error::Result; + +#[test] +fn test_binding_manager_number_assignment() -> Result<()> { + let mut m = BindingManager::new(); + let mut n: u16; + for i in 0..10 { + n = m.assign_channel_number(); + assert_eq!(MIN_CHANNEL_NUMBER + i, n, "should match"); + } + + m.next = 0x7ff0; + for i in 0..16 { + n = m.assign_channel_number(); + assert_eq!(0x7ff0 + i, n, "should match"); + } + // back to min + n = m.assign_channel_number(); + assert_eq!(MIN_CHANNEL_NUMBER, n, "should 
match"); + + Ok(()) +} + +#[test] +fn test_binding_manager_method() -> Result<()> { + let lo = Ipv4Addr::new(127, 0, 0, 1); + let count = 100; + let mut m = BindingManager::new(); + for i in 0..count { + let addr = SocketAddr::V4(SocketAddrV4::new(lo, 10000 + i)); + let b0 = { + let b = m.create(addr); + *b.unwrap() + }; + let b1 = m.find_by_addr(&addr); + assert!(b1.is_some(), "should succeed"); + let b2 = m.find_by_number(b0.number); + assert!(b2.is_some(), "should succeed"); + + assert_eq!(b0, *b1.unwrap(), "should match"); + assert_eq!(b0, *b2.unwrap(), "should match"); + } + + assert_eq!(count, m.size() as u16, "should match"); + assert_eq!(count, m.addr_map.len() as u16, "should match"); + + for i in 0..count { + let addr = SocketAddr::V4(SocketAddrV4::new(lo, 10000 + i)); + if i % 2 == 0 { + assert!(m.delete_by_addr(&addr), "should return true"); + } else { + assert!( + m.delete_by_number(MIN_CHANNEL_NUMBER + i), + "should return true" + ); + } + } + + assert_eq!(0, m.size(), "should match"); + assert_eq!(0, m.addr_map.len(), "should match"); + + Ok(()) +} + +#[test] +fn test_binding_manager_failure() -> Result<()> { + let ipv4 = Ipv4Addr::new(127, 0, 0, 1); + let addr = SocketAddr::V4(SocketAddrV4::new(ipv4, 7777)); + let mut m = BindingManager::new(); + let b = m.find_by_addr(&addr); + assert!(b.is_none(), "should fail"); + let b = m.find_by_number(5555); + assert!(b.is_none(), "should fail"); + let ok = m.delete_by_addr(&addr); + assert!(!ok, "should fail"); + let ok = m.delete_by_number(5555); + assert!(!ok, "should fail"); + + Ok(()) +} diff --git a/rtc-turn/src/client/client_test.rs b/rtc-turn/src/client/client_test.rs new file mode 100644 index 0000000..9f08f6b --- /dev/null +++ b/rtc-turn/src/client/client_test.rs @@ -0,0 +1,120 @@ +use tokio::net::UdpSocket; + +use super::*; +use crate::auth::*; + +async fn create_listening_test_client(rto_in_ms: u16) -> Result { + let conn = UdpSocket::bind("0.0.0.0:0").await?; + + let c = Client::new(ClientConfig { + stun_serv_addr: String::new(), + turn_serv_addr: String::new(), + username: String::new(), + password: String::new(), + realm: String::new(), + software: "TEST SOFTWARE".to_owned(), + rto_in_ms, + conn: Arc::new(conn), + vnet: None, + }) + .await?; + + c.listen().await?; + + Ok(c) +} + +async fn create_listening_test_client_with_stun_serv() -> Result { + let conn = UdpSocket::bind("0.0.0.0:0").await?; + + let c = Client::new(ClientConfig { + stun_serv_addr: "stun1.l.google.com:19302".to_owned(), + turn_serv_addr: String::new(), + username: String::new(), + password: String::new(), + realm: String::new(), + software: "TEST SOFTWARE".to_owned(), + rto_in_ms: 0, + conn: Arc::new(conn), + vnet: None, + }) + .await?; + + c.listen().await?; + + Ok(c) +} + +#[tokio::test] +async fn test_client_with_stun_send_binding_request() -> Result<()> { + //env_logger::init(); + + let c = create_listening_test_client_with_stun_serv().await?; + + let resp = c.send_binding_request().await?; + log::debug!("mapped-addr: {}", resp); + { + let ci = c.client_internal.lock().await; + let tm = ci.tr_map.lock().await; + assert_eq!(0, tm.size(), "should be no transaction left"); + } + + c.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_client_with_stun_send_binding_request_to_parallel() -> Result<()> { + env_logger::init(); + + let c1 = create_listening_test_client(0).await?; + let c2 = c1.clone(); + + let (stared_tx, mut started_rx) = mpsc::channel::<()>(1); + let (finished_tx, mut finished_rx) = mpsc::channel::<()>(1); + + let to = 
lookup_host(true, "stun1.l.google.com:19302").await?; + + tokio::spawn(async move { + drop(stared_tx); + if let Ok(resp) = c2.send_binding_request_to(&to.to_string()).await { + log::debug!("mapped-addr: {}", resp); + } + drop(finished_tx); + }); + + let _ = started_rx.recv().await; + + let resp = c1.send_binding_request_to(&to.to_string()).await?; + log::debug!("mapped-addr: {}", resp); + + let _ = finished_rx.recv().await; + + c1.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_client_with_stun_send_binding_request_to_timeout() -> Result<()> { + //env_logger::init(); + + let c = create_listening_test_client(10).await?; + + let to = lookup_host(true, "127.0.0.1:9").await?; + + let result = c.send_binding_request_to(&to.to_string()).await; + assert!(result.is_err(), "expected error, but got ok"); + + c.close().await?; + + Ok(()) +} + +struct TestAuthHandler; +impl AuthHandler for TestAuthHandler { + fn auth_handle(&self, username: &str, realm: &str, _src_addr: SocketAddr) -> Result> { + Ok(generate_auth_key(username, realm, "pass")) + } +} diff --git a/rtc-turn/src/client/mod.rs b/rtc-turn/src/client/mod.rs new file mode 100644 index 0000000..5df3bfd --- /dev/null +++ b/rtc-turn/src/client/mod.rs @@ -0,0 +1,572 @@ +#[cfg(test)] +mod client_test; + +pub mod binding; +pub mod periodic_timer; +pub mod permission; +pub mod relay_conn; +pub mod transaction; + +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use binding::*; +use bytes::BytesMut; +use relay_conn::*; +use std::collections::VecDeque; +use std::net::SocketAddr; +use std::str::FromStr; +use std::time::Instant; + +use stun::attributes::*; +use stun::error_code::*; +use stun::fingerprint::*; +use stun::integrity::*; +use stun::message::*; +use stun::textattrs::*; +use stun::xoraddr::*; +use transaction::*; + +use crate::proto::chandata::*; +use crate::proto::data::*; +use crate::proto::lifetime::*; +use crate::proto::peeraddr::*; +use crate::proto::relayaddr::*; +use crate::proto::reqtrans::*; +use crate::proto::PROTO_UDP; +use shared::error::{Error, Result}; +use stun::Transmit; + +const DEFAULT_RTO_IN_MS: u16 = 200; +const MAX_DATA_BUFFER_SIZE: usize = u16::MAX as usize; // message size limit for Chromium +const MAX_READ_QUEUE_SIZE: usize = 1024; + +// interval [msec] +// 0: 0 ms +500 +// 1: 500 ms +1000 +// 2: 1500 ms +2000 +// 3: 3500 ms +4000 +// 4: 7500 ms +8000 +// 5: 15500 ms +16000 +// 6: 31500 ms +32000 +// -: 63500 ms failed + +/// ClientConfig is a bag of config parameters for Client. +pub struct ClientConfig { + pub stun_serv_addr: String, // STUN server address (e.g. "stun.abc.com:3478") + pub turn_serv_addr: String, // TURN server address (e.g. 
"turn.abc.com:3478") + pub username: String, + pub password: String, + pub realm: String, + pub software: String, + pub rto_in_ms: u16, +} + +/// Client is a STUN client +struct Client { + stun_serv_addr: Option, + turn_serv_addr: SocketAddr, + username: Username, + password: String, + realm: Realm, + integrity: MessageIntegrity, + software: Software, + tr_map: TransactionMap, + binding_mgr: BindingManager, + rto_in_ms: u16, + transmits: VecDeque, +} + +impl RelayConnObserver for Client { + /// turn_server_addr return the TURN server address + fn turn_server_addr(&self) -> SocketAddr { + self.turn_serv_addr + } + + /// username returns username + fn username(&self) -> Username { + self.username.clone() + } + + /// realm return realm + fn realm(&self) -> Realm { + self.realm.clone() + } + + /// WriteTo sends data to the specified destination using the base socket. + fn write(&mut self, data: &[u8], remote: SocketAddr) -> Result { + let n = data.len(); + self.transmits.push_back(Transmit { + now: Instant::now(), + remote, + ecn: None, + local_ip: None, + payload: BytesMut::from(data), + }); + Ok(n) + } + + // PerformTransaction performs STUN transaction + fn perform_transaction( + &mut self, + msg: &Message, + to: &str, + ignore_result: bool, + ) -> Result { + let tr_key = BASE64_STANDARD.encode(msg.transaction_id.0); + + let mut tr = Transaction::new(TransactionConfig { + key: tr_key.clone(), + raw: msg.raw.clone(), + to: to.to_string(), + interval: self.rto_in_ms, + ignore_result, + }); + let result_ch_rx = tr.get_result_channel(); + + log::trace!("start {} transaction {} to {}", msg.typ, tr_key, tr.to); + { + let mut tm = self.tr_map.lock().await; + tm.insert(tr_key.clone(), tr); + } + + self.conn + .send_to(&msg.raw, SocketAddr::from_str(to)?) + .await?; + + let conn2 = Arc::clone(&self.conn); + let tr_map2 = Arc::clone(&self.tr_map); + { + let mut tm = self.tr_map.lock().await; + if let Some(tr) = tm.get(&tr_key) { + tr.start_rtx_timer(conn2, tr_map2).await; + } + } + + // If dontWait is true, get the transaction going and return immediately + if ignore_result { + return Ok(TransactionResult::default()); + } + + // wait_for_result waits for the transaction result + if let Some(mut result_ch_rx) = result_ch_rx { + match result_ch_rx.recv().await { + Some(tr) => Ok(tr), + None => Err(Error::ErrTransactionClosed), + } + } else { + Err(Error::ErrWaitForResultOnNonResultTransaction) + } + } +} + +impl Client { + /// new returns a new Client instance. listeningAddress is the address and port to listen on, default "0.0.0.0:0" + pub fn new(config: ClientConfig) -> Result { + let stun_serv_addr = if config.stun_serv_addr.is_empty() { + None + } else { + Some(SocketAddr::from_str(config.stun_serv_addr.as_str())?) + }; + + let turn_serv_addr = if config.turn_serv_addr.is_empty() { + return Err(Error::ErrNilTurnSocket); + } else { + SocketAddr::from_str(config.turn_serv_addr.as_str())? 
+ }; + + Ok(Client { + stun_serv_addr, + turn_serv_addr, + username: Username::new(ATTR_USERNAME, config.username), + password: config.password, + realm: Realm::new(ATTR_REALM, config.realm), + software: Software::new(ATTR_SOFTWARE, config.software), + tr_map: TransactionMap::new(), + binding_mgr: BindingManager::new(), + rto_in_ms: if config.rto_in_ms != 0 { + config.rto_in_ms + } else { + DEFAULT_RTO_IN_MS + }, + integrity: MessageIntegrity::new_short_term_integrity(String::new()), + transmits: VecDeque::new(), + }) + } + + // stun_server_addr return the STUN server address + fn stun_server_addr(&self) -> Option { + self.stun_serv_addr + } + + /// Listen will have this client start listening on the relay_conn provided via the config. + /// This is optional. If not used, you will need to call handle_inbound method + /// to supply incoming data, instead. + pub fn listen(&mut self) -> Result<()> { + let conn = Arc::clone(&self.conn); + let stun_serv_str = self.stun_serv_addr.clone(); + let tr_map = Arc::clone(&self.tr_map); + let read_ch_tx = Arc::clone(&self.read_ch_tx); + let binding_mgr = Arc::clone(&self.binding_mgr); + + tokio::spawn(async move { + let mut buf = vec![0u8; MAX_DATA_BUFFER_SIZE]; + loop { + //TODO: gracefully exit loop + let (n, from) = match conn.recv_from(&mut buf).await { + Ok((n, from)) => (n, from), + Err(err) => { + log::debug!("exiting read loop: {}", err); + break; + } + }; + + log::debug!("received {} bytes of udp from {}", n, from); + + if let Err(err) = ClientInternal::handle_inbound( + &read_ch_tx, + &buf[..n], + from, + &stun_serv_str, + &tr_map, + &binding_mgr, + ) + .await + { + log::debug!("exiting read loop: {}", err); + break; + } + } + }); + + Ok(()) + } + + // handle_inbound handles data received. + // This method handles incoming packet demultiplex it by the source address + // and the types of the message. + // This return a booleen (handled or not) and if there was an error. + // Caller should check if the packet was handled by this client or not. + // If not handled, it is assumed that the packet is application data. + // If an error is returned, the caller should discard the packet regardless. 
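+    // Dispatch order in the body below: STUN messages first, then TURN
+    // ChannelData frames, then non-STUN traffic from the configured STUN
+    // server (treated as an error); anything else falls through as
+    // application data.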
+ async fn handle_inbound( + read_ch_tx: &Arc>>>, + data: &[u8], + from: SocketAddr, + stun_serv_str: &str, + tr_map: &Arc>, + binding_mgr: &Arc>, + ) -> Result<()> { + // +-------------------+-------------------------------+ + // | Return Values | | + // +-------------------+ Meaning / Action | + // | handled | error | | + // |=========+=========+===============================+ + // | false | nil | Handle the packet as app data | + // |---------+---------+-------------------------------+ + // | true | nil | Nothing to do | + // |---------+---------+-------------------------------+ + // | false | error | (shouldn't happen) | + // |---------+---------+-------------------------------+ + // | true | error | Error occurred while handling | + // +---------+---------+-------------------------------+ + // Possible causes of the error: + // - Malformed packet (parse error) + // - STUN message was a request + // - Non-STUN message from the STUN server + + if is_message(data) { + ClientInternal::handle_stun_message(tr_map, read_ch_tx, data, from).await + } else if ChannelData::is_channel_data(data) { + ClientInternal::handle_channel_data(binding_mgr, read_ch_tx, data).await + } else if !stun_serv_str.is_empty() && from.to_string() == *stun_serv_str { + // received from STUN server but it is not a STUN message + Err(Error::ErrNonStunmessage) + } else { + // assume, this is an application data + log::trace!("non-STUN/TURN packect, unhandled"); + Ok(()) + } + } + + async fn handle_stun_message( + tr_map: &Arc>, + read_ch_tx: &Arc>>>, + data: &[u8], + mut from: SocketAddr, + ) -> Result<()> { + let mut msg = Message::new(); + msg.raw = data.to_vec(); + msg.decode()?; + + if msg.typ.class == CLASS_REQUEST { + return Err(Error::Other(format!( + "{:?} : {}", + Error::ErrUnexpectedStunrequestMessage, + msg + ))); + } + + if msg.typ.class == CLASS_INDICATION { + if msg.typ.method == METHOD_DATA { + let mut peer_addr = PeerAddress::default(); + peer_addr.get_from(&msg)?; + from = SocketAddr::new(peer_addr.ip, peer_addr.port); + + let mut data = Data::default(); + data.get_from(&msg)?; + + log::debug!("data indication received from {}", from); + + let _ = ClientInternal::handle_inbound_relay_conn(read_ch_tx, &data.0, from).await; + } + + return Ok(()); + } + + // This is a STUN response message (transactional) + // The type is either: + // - stun.ClassSuccessResponse + // - stun.ClassErrorResponse + + let tr_key = BASE64_STANDARD.encode(msg.transaction_id.0); + + let mut tm = tr_map.lock().await; + if tm.find(&tr_key).is_none() { + // silently discard + log::debug!("no transaction for {}", msg); + return Ok(()); + } + + if let Some(mut tr) = tm.delete(&tr_key) { + // End the transaction + tr.stop_rtx_timer(); + + if !tr + .write_result(TransactionResult { + msg, + from, + retries: tr.retries(), + ..Default::default() + }) + .await + { + log::debug!("no listener for msg.raw {:?}", data); + } + } + + Ok(()) + } + + async fn handle_channel_data( + binding_mgr: &Arc>, + read_ch_tx: &Arc>>>, + data: &[u8], + ) -> Result<()> { + let mut ch_data = ChannelData { + raw: data.to_vec(), + ..Default::default() + }; + ch_data.decode()?; + + let addr = ClientInternal::find_addr_by_channel_number(binding_mgr, ch_data.number.0) + .await + .ok_or(Error::ErrChannelBindNotFound)?; + + log::trace!( + "channel data received from {} (ch={})", + addr, + ch_data.number.0 + ); + + let _ = ClientInternal::handle_inbound_relay_conn(read_ch_tx, &ch_data.data, addr).await; + + Ok(()) + } + + // handle_inbound_relay_conn passes inbound 
data in RelayConn + async fn handle_inbound_relay_conn( + read_ch_tx: &Arc>>>, + data: &[u8], + from: SocketAddr, + ) -> Result<()> { + let read_ch_tx_opt = read_ch_tx.lock().await; + log::debug!("read_ch_tx_opt = {}", read_ch_tx_opt.is_some()); + if let Some(tx) = &*read_ch_tx_opt { + log::debug!("try_send data = {:?}, from = {}", data, from); + if tx + .try_send(InboundData { + data: data.to_vec(), + from, + }) + .is_err() + { + log::warn!("receive buffer full"); + } + Ok(()) + } else { + Err(Error::ErrAlreadyClosed) + } + } + + /// Close closes this client + pub fn close(&mut self) { + { + let mut read_ch_tx = self.read_ch_tx.lock().await; + read_ch_tx.take(); + } + { + let mut tm = self.tr_map.lock().await; + tm.close_and_delete_all(); + } + } + + /// send_binding_request_to sends a new STUN request to the given transport address + pub fn send_binding_request_to(&mut self, to: SocketAddr) -> Result { + let msg = { + let attrs: Vec> = if !self.software.text.is_empty() { + vec![ + Box::new(TransactionId::new()), + Box::new(BINDING_REQUEST), + Box::new(self.software.clone()), + ] + } else { + vec![Box::new(TransactionId::new()), Box::new(BINDING_REQUEST)] + }; + + let mut msg = Message::new(); + msg.build(&attrs)?; + msg + }; + + log::debug!("client.SendBindingRequestTo call PerformTransaction 1"); + let tr_res = self.perform_transaction(&msg, to, false)?; + + let mut refl_addr = XorMappedAddress::default(); + refl_addr.get_from(&tr_res.msg)?; + + Ok(SocketAddr::new(refl_addr.ip, refl_addr.port)) + } + + /// send_binding_request sends a new STUN request to the STUN server + pub fn send_binding_request(&mut self) -> Result { + if self.stun_serv_addr.is_empty() { + Err(Error::ErrStunserverAddressNotSet) + } else { + self.send_binding_request_to(&self.stun_serv_addr) + } + } + + // find_addr_by_channel_number returns a peer address associated with the + // channel number on this UDPConn + async fn find_addr_by_channel_number( + binding_mgr: &Arc>, + ch_num: u16, + ) -> Option { + let bm = binding_mgr.lock().await; + bm.find_by_number(ch_num).map(|b| b.addr) + } + + /// Allocate sends a TURN allocation request to the given transport address + pub fn allocate(&mut self) -> Result { + { + let read_ch_tx = self.read_ch_tx.lock().await; + log::debug!("allocate check: read_ch_tx_opt = {}", read_ch_tx.is_some()); + if read_ch_tx.is_some() { + return Err(Error::ErrOneAllocateOnly); + } + } + + let mut msg = Message::new(); + msg.build(&[ + Box::new(TransactionId::new()), + Box::new(MessageType::new(METHOD_ALLOCATE, CLASS_REQUEST)), + Box::new(RequestedTransport { + protocol: PROTO_UDP, + }), + Box::new(FINGERPRINT), + ])?; + + log::debug!("client.Allocate call PerformTransaction 1"); + let tr_res = self + .perform_transaction(&msg, &self.turn_serv_addr.clone(), false) + .await?; + let res = tr_res.msg; + + // Anonymous allocate failed, trying to authenticate. + let nonce = Nonce::get_from_as(&res, ATTR_NONCE)?; + self.realm = Realm::get_from_as(&res, ATTR_REALM)?; + + self.integrity = MessageIntegrity::new_long_term_integrity( + self.username.text.clone(), + self.realm.text.clone(), + self.password.clone(), + ); + + // Trying to authorize. 
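+        // Retry the ALLOCATE request with long-term credentials: USERNAME, REALM,
+        // the server-provided NONCE, and a MESSAGE-INTEGRITY derived from
+        // username/realm/password.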
+ msg.build(&[ + Box::new(TransactionId::new()), + Box::new(MessageType::new(METHOD_ALLOCATE, CLASS_REQUEST)), + Box::new(RequestedTransport { + protocol: PROTO_UDP, + }), + Box::new(self.username.clone()), + Box::new(self.realm.clone()), + Box::new(nonce.clone()), + Box::new(self.integrity.clone()), + Box::new(FINGERPRINT), + ])?; + + log::debug!("client.Allocate call PerformTransaction 2"); + let tr_res = self + .perform_transaction(&msg, &self.turn_serv_addr.clone(), false) + .await?; + let res = tr_res.msg; + + if res.typ.class == CLASS_ERROR_RESPONSE { + let mut code = ErrorCodeAttribute::default(); + let result = code.get_from(&res); + if result.is_err() { + return Err(Error::Other(format!("{}", res.typ))); + } else { + return Err(Error::Other(format!("{} (error {})", res.typ, code))); + } + } + + // Getting relayed addresses from response. + let mut relayed = RelayedAddress::default(); + relayed.get_from(&res)?; + let relayed_addr = SocketAddr::new(relayed.ip, relayed.port); + + // Getting lifetime from response + let mut lifetime = Lifetime::default(); + lifetime.get_from(&res)?; + + let (read_ch_tx, read_ch_rx) = mpsc::channel(MAX_READ_QUEUE_SIZE); + { + let mut read_ch_tx_opt = self.read_ch_tx.lock().await; + *read_ch_tx_opt = Some(read_ch_tx); + log::debug!("allocate: read_ch_tx_opt = {}", read_ch_tx_opt.is_some()); + } + + Ok(RelayConnConfig { + relayed_addr, + integrity: self.integrity.clone(), + nonce, + lifetime: lifetime.0, + binding_mgr: Arc::clone(&self.binding_mgr), + read_ch_rx: Arc::new(Mutex::new(read_ch_rx)), + }) + } +} + +/*TODO: +impl Client { + pub async fn allocate(&self) -> Result { + let config = { + let mut ci = self.client_internal.lock().await; + ci.allocate().await? + }; + + Ok(RelayConn::new(Arc::clone(&self.client_internal), config).await) + } +}*/ diff --git a/rtc-turn/src/client/periodic_timer.rs b/rtc-turn/src/client/periodic_timer.rs new file mode 100644 index 0000000..034ddf5 --- /dev/null +++ b/rtc-turn/src/client/periodic_timer.rs @@ -0,0 +1,89 @@ +#[cfg(test)] +mod periodic_timer_test; + +use std::time::Duration; + +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum TimerIdRefresh { + #[default] + Alloc, + Perms, +} + +// PeriodicTimerTimeoutHandler is a handler called on timeout +#[async_trait] +pub trait PeriodicTimerTimeoutHandler { + async fn on_timeout(&mut self, id: TimerIdRefresh); +} + +// PeriodicTimer is a periodic timer +#[derive(Default)] +pub struct PeriodicTimer { + id: TimerIdRefresh, + interval: Duration, + close_tx: Mutex>>, +} + +impl PeriodicTimer { + // create a new timer + pub fn new(id: TimerIdRefresh, interval: Duration) -> Self { + PeriodicTimer { + id, + interval, + close_tx: Mutex::new(None), + } + } + + // Start starts the timer. + pub async fn start( + &self, + timeout_handler: Arc>, + ) -> bool { + // this is a noop if the timer is always running + { + let close_tx = self.close_tx.lock().await; + if close_tx.is_some() { + return false; + } + } + + let (close_tx, mut close_rx) = mpsc::channel(1); + let interval = self.interval; + let id = self.id; + + tokio::spawn(async move { + loop { + let timer = tokio::time::sleep(interval); + tokio::pin!(timer); + + tokio::select! { + _ = timer.as_mut() => { + let mut handler = timeout_handler.lock().await; + handler.on_timeout(id).await; + } + _ = close_rx.recv() => break, + } + } + }); + + { + let mut close = self.close_tx.lock().await; + *close = Some(close_tx); + } + + true + } + + // Stop stops the timer. 
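+    // Taking the sender out of `close_tx` drops it, so the spawned loop's
+    // `close_rx.recv()` yields `None` and the timer task exits.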
+ pub async fn stop(&self) { + let mut close_tx = self.close_tx.lock().await; + close_tx.take(); + } + + // is_running tests if the timer is running. + // Debug purpose only + pub async fn is_running(&self) -> bool { + let close_tx = self.close_tx.lock().await; + close_tx.is_some() + } +} diff --git a/rtc-turn/src/client/periodic_timer/periodic_timer_test.rs b/rtc-turn/src/client/periodic_timer/periodic_timer_test.rs new file mode 100644 index 0000000..b67b31d --- /dev/null +++ b/rtc-turn/src/client/periodic_timer/periodic_timer_test.rs @@ -0,0 +1,37 @@ +use super::*; +use crate::error::Result; + +struct DummyPeriodicTimerTimeoutHandler; + +#[async_trait] +impl PeriodicTimerTimeoutHandler for DummyPeriodicTimerTimeoutHandler { + async fn on_timeout(&mut self, id: TimerIdRefresh) { + assert_eq!(id, TimerIdRefresh::Perms); + } +} + +#[tokio::test] +async fn test_periodic_timer() -> Result<()> { + let timer_id = TimerIdRefresh::Perms; + let rt = PeriodicTimer::new(timer_id, Duration::from_millis(50)); + let dummy1 = Arc::new(Mutex::new(DummyPeriodicTimerTimeoutHandler {})); + let dummy2 = Arc::clone(&dummy1); + + assert!(!rt.is_running().await, "should not be running yet"); + + let ok = rt.start(dummy1).await; + assert!(ok, "should be true"); + assert!(rt.is_running().await, "should be running"); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let ok = rt.start(dummy2).await; + assert!(!ok, "start again is noop"); + + tokio::time::sleep(Duration::from_millis(120)).await; + rt.stop().await; + + assert!(!rt.is_running().await, "should not be running"); + + Ok(()) +} diff --git a/rtc-turn/src/client/permission.rs b/rtc-turn/src/client/permission.rs new file mode 100644 index 0000000..234cbf4 --- /dev/null +++ b/rtc-turn/src/client/permission.rs @@ -0,0 +1,71 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::Arc; + +#[derive(Default, Copy, Clone, PartialEq, Debug)] +pub(crate) enum PermState { + #[default] + Idle = 0, + Permitted = 1, +} + +impl From for PermState { + fn from(v: u8) -> Self { + match v { + 0 => PermState::Idle, + _ => PermState::Permitted, + } + } +} + +#[derive(Default)] +pub(crate) struct Permission { + st: AtomicU8, //PermState, +} + +impl Permission { + pub(crate) fn set_state(&self, state: PermState) { + self.st.store(state as u8, Ordering::SeqCst); + } + + pub(crate) fn state(&self) -> PermState { + self.st.load(Ordering::SeqCst).into() + } +} + +// Thread-safe Permission map +#[derive(Default)] +pub(crate) struct PermissionMap { + perm_map: HashMap>, +} + +impl PermissionMap { + pub(crate) fn new() -> PermissionMap { + PermissionMap { + perm_map: HashMap::new(), + } + } + + pub(crate) fn insert(&mut self, addr: &SocketAddr, p: Arc) { + self.perm_map.insert(addr.ip().to_string(), p); + } + + pub(crate) fn find(&self, addr: &SocketAddr) -> Option<&Arc> { + self.perm_map.get(&addr.ip().to_string()) + } + + pub(crate) fn delete(&mut self, addr: &SocketAddr) { + self.perm_map.remove(&addr.ip().to_string()); + } + + pub(crate) fn addrs(&self) -> Vec { + let mut a = vec![]; + for k in self.perm_map.keys() { + if let Ok(ip) = k.parse() { + a.push(SocketAddr::new(ip, 0)); + } + } + a + } +} diff --git a/rtc-turn/src/client/relay_conn.rs b/rtc-turn/src/client/relay_conn.rs new file mode 100644 index 0000000..d0aaa70 --- /dev/null +++ b/rtc-turn/src/client/relay_conn.rs @@ -0,0 +1,615 @@ +#[cfg(test)] +mod relay_conn_test; + +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +use 
stun::attributes::*; +use stun::error_code::*; +use stun::fingerprint::*; +use stun::integrity::*; +use stun::message::*; +use stun::textattrs::*; + +use super::binding::*; +use super::periodic_timer::*; +use super::permission::*; +use super::transaction::*; +use crate::proto; + +use shared::error::{Error, Result}; +use stun::Transmit; + +const PERM_REFRESH_INTERVAL: Duration = Duration::from_secs(120); +const MAX_RETRY_ATTEMPTS: u16 = 3; + +pub(crate) struct InboundData { + pub(crate) data: Vec, + pub(crate) from: SocketAddr, +} + +pub trait RelayConnObserver { + fn write(&mut self, data: &[u8], remote: SocketAddr) -> Result; + + fn turn_server_addr(&self) -> String; + fn username(&self) -> Username; + fn realm(&self) -> Realm; + fn perform_transaction( + &mut self, + msg: &Message, + to: SocketAddr, + ignore_result: bool, + ) -> Result; +} + +// RelayConnConfig is a set of configuration params use by NewUDPConn +pub(crate) struct RelayConnConfig { + pub(crate) relayed_addr: SocketAddr, + pub(crate) integrity: MessageIntegrity, + pub(crate) nonce: Nonce, + pub(crate) lifetime: Duration, + pub(crate) binding_mgr: BindingManager, +} + +pub struct RelayConnInternal { + obs: Arc>, + relayed_addr: SocketAddr, + perm_map: PermissionMap, + binding_mgr: Arc>, + integrity: MessageIntegrity, + nonce: Nonce, + lifetime: Duration, +} + +// RelayConn is the implementation of the Conn interfaces for UDP Relayed network connections. +pub struct RelayConn { + relayed_addr: SocketAddr, + read_ch_rx: Arc>>, + relay_conn: Arc>>, + refresh_alloc_timer: PeriodicTimer, + refresh_perms_timer: PeriodicTimer, +} + +impl RelayConn { + // new creates a new instance of UDPConn + pub(crate) async fn new(obs: Arc>, config: RelayConnConfig) -> Self { + log::debug!("initial lifetime: {} seconds", config.lifetime.as_secs()); + + let c = RelayConn { + refresh_alloc_timer: PeriodicTimer::new(TimerIdRefresh::Alloc, config.lifetime / 2), + refresh_perms_timer: PeriodicTimer::new(TimerIdRefresh::Perms, PERM_REFRESH_INTERVAL), + relayed_addr: config.relayed_addr, + read_ch_rx: Arc::clone(&config.read_ch_rx), + relay_conn: Arc::new(Mutex::new(RelayConnInternal::new(obs, config))), + }; + + let rci1 = Arc::clone(&c.relay_conn); + let rci2 = Arc::clone(&c.relay_conn); + + if c.refresh_alloc_timer.start(rci1).await { + log::debug!("refresh_alloc_timer started"); + } + if c.refresh_perms_timer.start(rci2).await { + log::debug!("refresh_perms_timer started"); + } + + c + } +} +/*TODO: +#[async_trait] +impl Conn for RelayConn { + async fn connect(&self, _addr: SocketAddr) -> Result<()> { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + async fn recv(&self, _buf: &mut [u8]) -> Result { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + // ReadFrom reads a packet from the connection, + // copying the payload into p. It returns the number of + // bytes copied into p and the return address that + // was on the packet. + // It returns the number of bytes read (0 <= n <= len(p)) + // and any error encountered. Callers should always process + // the n > 0 bytes returned before considering the error err. + // ReadFrom can be made to time out and return + // an Error with Timeout() == true after a fixed time limit; + // see SetDeadline and SetReadDeadline. 
+ async fn recv_from(&self, p: &mut [u8]) -> Result<(usize, SocketAddr)> { + let mut read_ch_rx = self.read_ch_rx.lock().await; + + if let Some(ib_data) = read_ch_rx.recv().await { + let n = ib_data.data.len(); + if p.len() < n { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + Error::ErrShortBuffer.to_string(), + ) + .into()); + } + p[..n].copy_from_slice(&ib_data.data); + Ok((n, ib_data.from)) + } else { + Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + Error::ErrAlreadyClosed.to_string(), + ) + .into()) + } + } + + async fn send(&self, _buf: &[u8]) -> Result { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + // write_to writes a packet with payload p to addr. + // write_to can be made to time out and return + // an Error with Timeout() == true after a fixed time limit; + // see SetDeadline and SetWriteDeadline. + // On packet-oriented connections, write timeouts are rare. + async fn send_to(&self, p: &[u8], addr: SocketAddr) -> Result { + let mut relay_conn = self.relay_conn.lock().await; + match relay_conn.send_to(p, addr).await { + Ok(n) => Ok(n), + Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string()).into()), + } + } + + // LocalAddr returns the local network address. + fn local_addr(&self) -> Result { + Ok(self.relayed_addr) + } + + fn remote_addr(&self) -> Option { + None + } + + // Close closes the connection. + // Any blocked ReadFrom or write_to operations will be unblocked and return errors. + async fn close(&self) -> Result<()> { + self.refresh_alloc_timer.stop().await; + self.refresh_perms_timer.stop().await; + + let mut relay_conn = self.relay_conn.lock().await; + let _ = relay_conn + .close() + .await + .map_err(|err| util::Error::Other(format!("{err}"))); + Ok(()) + } +}*/ + +impl RelayConnInternal { + // new creates a new instance of UDPConn + fn new(obs: Arc>, config: RelayConnConfig) -> Self { + RelayConnInternal { + obs, + relayed_addr: config.relayed_addr, + perm_map: PermissionMap::new(), + binding_mgr: config.binding_mgr, + integrity: config.integrity, + nonce: config.nonce, + lifetime: config.lifetime, + } + } + + // write_to writes a packet with payload p to addr. + // write_to can be made to time out and return + // an Error with Timeout() == true after a fixed time limit; + // see SetDeadline and SetWriteDeadline. + // On packet-oriented connections, write timeouts are rare. + async fn send_to(&mut self, p: &[u8], addr: SocketAddr) -> Result { + // check if we have a permission for the destination IP addr + let perm = if let Some(perm) = self.perm_map.find(&addr) { + Arc::clone(perm) + } else { + let perm = Arc::new(Permission::default()); + self.perm_map.insert(&addr, Arc::clone(&perm)); + perm + }; + + let mut result = Ok(()); + for _ in 0..MAX_RETRY_ATTEMPTS { + result = self.create_perm(&perm, addr).await; + if let Err(err) = &result { + if Error::ErrTryAgain != *err { + break; + } + } + } + result?; + + let number = { + let (bind_st, bind_at, bind_number, bind_addr) = { + let mut binding_mgr = self.binding_mgr.lock().await; + let b = if let Some(b) = binding_mgr.find_by_addr(&addr) { + b + } else { + binding_mgr + .create(addr) + .ok_or_else(|| Error::Other("Addr not found".to_owned()))? 
+ }; + (b.state(), b.refreshed_at(), b.number, b.addr) + }; + + if bind_st == BindingState::Idle + || bind_st == BindingState::Request + || bind_st == BindingState::Failed + { + // block only callers with the same binding until + // the binding transaction has been complete + // binding state may have been changed while waiting. check again. + if bind_st == BindingState::Idle { + let binding_mgr = Arc::clone(&self.binding_mgr); + let rc_obs = Arc::clone(&self.obs); + let nonce = self.nonce.clone(); + let integrity = self.integrity.clone(); + { + let mut bm = binding_mgr.lock().await; + if let Some(b) = bm.get_by_addr(&bind_addr) { + b.set_state(BindingState::Request); + } + } + tokio::spawn(async move { + let result = RelayConnInternal::bind( + rc_obs, + bind_addr, + bind_number, + nonce, + integrity, + ) + .await; + + { + let mut bm = binding_mgr.lock().await; + if let Err(err) = result { + if Error::ErrUnexpectedResponse != err { + bm.delete_by_addr(&bind_addr); + } else if let Some(b) = bm.get_by_addr(&bind_addr) { + b.set_state(BindingState::Failed); + } + + // keep going... + log::warn!("bind() failed: {}", err); + } else if let Some(b) = bm.get_by_addr(&bind_addr) { + b.set_state(BindingState::Ready); + } + } + }); + } + + // send data using SendIndication + let peer_addr = socket_addr2peer_address(&addr); + let mut msg = Message::new(); + msg.build(&[ + Box::new(TransactionId::new()), + Box::new(MessageType::new(METHOD_SEND, CLASS_INDICATION)), + Box::new(proto::data::Data(p.to_vec())), + Box::new(peer_addr), + Box::new(FINGERPRINT), + ])?; + + // indication has no transaction (fire-and-forget) + let obs = self.obs.lock().await; + let turn_server_addr = obs.turn_server_addr(); + return Ok(obs.write_to(&msg.raw, &turn_server_addr).await?); + } + + // binding is either ready + + // check if the binding needs a refresh + if bind_st == BindingState::Ready + && Instant::now() + .checked_duration_since(bind_at) + .unwrap_or_else(|| Duration::from_secs(0)) + > Duration::from_secs(5 * 60) + { + let binding_mgr = Arc::clone(&self.binding_mgr); + let rc_obs = Arc::clone(&self.obs); + let nonce = self.nonce.clone(); + let integrity = self.integrity.clone(); + { + let mut bm = binding_mgr.lock().await; + if let Some(b) = bm.get_by_addr(&bind_addr) { + b.set_state(BindingState::Refresh); + } + } + tokio::spawn(async move { + let result = + RelayConnInternal::bind(rc_obs, bind_addr, bind_number, nonce, integrity) + .await; + + { + let mut bm = binding_mgr.lock().await; + if let Err(err) = result { + if Error::ErrUnexpectedResponse != err { + bm.delete_by_addr(&bind_addr); + } else if let Some(b) = bm.get_by_addr(&bind_addr) { + b.set_state(BindingState::Failed); + } + + // keep going... + log::warn!("bind() for refresh failed: {}", err); + } else if let Some(b) = bm.get_by_addr(&bind_addr) { + b.set_refreshed_at(Instant::now()); + b.set_state(BindingState::Ready); + } + } + }); + } + + bind_number + }; + + // send via ChannelData + self.send_channel_data(p, number).await + } + + // This func-block would block, per destination IP (, or perm), until + // the perm state becomes "requested". Purpose of this is to guarantee + // the order of packets (within the same perm). + // Note that CreatePermission transaction may not be complete before + // all the data transmission. This is done assuming that the request + // will be mostly likely successful and we can tolerate some loss of + // UDP packet (or reorder), inorder to minimize the latency in most cases. 
+ async fn create_perm(&mut self, perm: &Arc, addr: SocketAddr) -> Result<()> { + if perm.state() == PermState::Idle { + // punch a hole! (this would block a bit..) + if let Err(err) = self.create_permissions(&[addr]).await { + self.perm_map.delete(&addr); + return Err(err); + } + perm.set_state(PermState::Permitted); + } + Ok(()) + } + + async fn send_channel_data(&self, data: &[u8], ch_num: u16) -> Result { + let mut ch_data = proto::chandata::ChannelData { + data: data.to_vec(), + number: proto::channum::ChannelNumber(ch_num), + ..Default::default() + }; + ch_data.encode(); + + let obs = self.obs.lock().await; + Ok(obs.write_to(&ch_data.raw, &obs.turn_server_addr()).await?) + } + + async fn create_permissions(&mut self, addrs: &[SocketAddr]) -> Result<()> { + let res = { + let msg = { + let obs = self.obs.lock().await; + let mut setters: Vec> = vec![ + Box::new(TransactionId::new()), + Box::new(MessageType::new(METHOD_CREATE_PERMISSION, CLASS_REQUEST)), + ]; + + for addr in addrs { + setters.push(Box::new(socket_addr2peer_address(addr))); + } + + setters.push(Box::new(obs.username())); + setters.push(Box::new(obs.realm())); + setters.push(Box::new(self.nonce.clone())); + setters.push(Box::new(self.integrity.clone())); + setters.push(Box::new(FINGERPRINT)); + + let mut msg = Message::new(); + msg.build(&setters)?; + msg + }; + + let mut obs = self.obs.lock().await; + let turn_server_addr = obs.turn_server_addr(); + + log::debug!("UDPConn.createPermissions call PerformTransaction 1"); + let tr_res = obs + .perform_transaction(&msg, &turn_server_addr, false) + .await?; + + tr_res.msg + }; + + if res.typ.class == CLASS_ERROR_RESPONSE { + let mut code = ErrorCodeAttribute::default(); + let result = code.get_from(&res); + if result.is_err() { + return Err(Error::Other(format!("{}", res.typ))); + } else if code.code == CODE_STALE_NONCE { + self.set_nonce_from_msg(&res); + return Err(Error::ErrTryAgain); + } else { + return Err(Error::Other(format!("{} (error {})", res.typ, code))); + } + } + + Ok(()) + } + + pub fn set_nonce_from_msg(&mut self, msg: &Message) { + // Update nonce + match Nonce::get_from_as(msg, ATTR_NONCE) { + Ok(nonce) => { + self.nonce = nonce; + log::debug!("refresh allocation: 438, got new nonce."); + } + Err(_) => log::warn!("refresh allocation: 438 but no nonce."), + } + } + + // Close closes the connection. + // Any blocked ReadFrom or write_to operations will be unblocked and return errors. 
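+    // Closing is implemented as a Refresh request with LIFETIME=0, which asks
+    // the TURN server to delete the allocation immediately.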
+ pub async fn close(&mut self) -> Result<()> { + self.refresh_allocation(Duration::from_secs(0), true /* dontWait=true */) + .await + } + + async fn refresh_allocation(&mut self, lifetime: Duration, dont_wait: bool) -> Result<()> { + let res = { + let mut obs = self.obs.lock().await; + + let mut msg = Message::new(); + msg.build(&[ + Box::new(TransactionId::new()), + Box::new(MessageType::new(METHOD_REFRESH, CLASS_REQUEST)), + Box::new(proto::lifetime::Lifetime(lifetime)), + Box::new(obs.username()), + Box::new(obs.realm()), + Box::new(self.nonce.clone()), + Box::new(self.integrity.clone()), + Box::new(FINGERPRINT), + ])?; + + log::debug!("send refresh request (dont_wait={})", dont_wait); + let turn_server_addr = obs.turn_server_addr(); + let tr_res = obs + .perform_transaction(&msg, &turn_server_addr, dont_wait) + .await?; + + if dont_wait { + log::debug!("refresh request sent"); + return Ok(()); + } + + log::debug!("refresh request sent, and waiting response"); + + tr_res.msg + }; + + if res.typ.class == CLASS_ERROR_RESPONSE { + let mut code = ErrorCodeAttribute::default(); + let result = code.get_from(&res); + if result.is_err() { + return Err(Error::Other(format!("{}", res.typ))); + } else if code.code == CODE_STALE_NONCE { + self.set_nonce_from_msg(&res); + return Err(Error::ErrTryAgain); + } else { + return Ok(()); + } + } + + // Getting lifetime from response + let mut updated_lifetime = proto::lifetime::Lifetime::default(); + updated_lifetime.get_from(&res)?; + + self.lifetime = updated_lifetime.0; + log::debug!("updated lifetime: {} seconds", self.lifetime.as_secs()); + Ok(()) + } + + async fn refresh_permissions(&mut self) -> Result<()> { + let addrs = self.perm_map.addrs(); + if addrs.is_empty() { + log::debug!("no permission to refresh"); + return Ok(()); + } + + if let Err(err) = self.create_permissions(&addrs).await { + if Error::ErrTryAgain != err { + log::error!("fail to refresh permissions: {}", err); + } + return Err(err); + } + + log::debug!("refresh permissions successful"); + Ok(()) + } + + async fn bind( + rc_obs: Arc>, + bind_addr: SocketAddr, + bind_number: u16, + nonce: Nonce, + integrity: MessageIntegrity, + ) -> Result<()> { + let (msg, turn_server_addr) = { + let obs = rc_obs.lock().await; + + let setters: Vec> = vec![ + Box::new(TransactionId::new()), + Box::new(MessageType::new(METHOD_CHANNEL_BIND, CLASS_REQUEST)), + Box::new(socket_addr2peer_address(&bind_addr)), + Box::new(proto::channum::ChannelNumber(bind_number)), + Box::new(obs.username()), + Box::new(obs.realm()), + Box::new(nonce), + Box::new(integrity), + Box::new(FINGERPRINT), + ]; + + let mut msg = Message::new(); + msg.build(&setters)?; + + (msg, obs.turn_server_addr()) + }; + + log::debug!("UDPConn.bind call PerformTransaction 1"); + let tr_res = { + let mut obs = rc_obs.lock().await; + obs.perform_transaction(&msg, &turn_server_addr, false) + .await? + }; + + let res = tr_res.msg; + + if res.typ != MessageType::new(METHOD_CHANNEL_BIND, CLASS_SUCCESS_RESPONSE) { + return Err(Error::ErrUnexpectedResponse); + } + + log::debug!("channel binding successful: {} {}", bind_addr, bind_number); + + // Success. 
+ Ok(()) + } +} + +#[async_trait] +impl PeriodicTimerTimeoutHandler for RelayConnInternal { + async fn on_timeout(&mut self, id: TimerIdRefresh) { + log::debug!("refresh timer {:?} expired", id); + match id { + TimerIdRefresh::Alloc => { + let lifetime = self.lifetime; + // limit the max retries on errTryAgain to 3 + // when stale nonce returns, sencond retry should succeed + let mut result = Ok(()); + for _ in 0..MAX_RETRY_ATTEMPTS { + result = self.refresh_allocation(lifetime, false).await; + if let Err(err) = &result { + if Error::ErrTryAgain != *err { + break; + } + } + } + if result.is_err() { + log::warn!("refresh allocation failed"); + } + } + TimerIdRefresh::Perms => { + let mut result = Ok(()); + for _ in 0..MAX_RETRY_ATTEMPTS { + result = self.refresh_permissions().await; + if let Err(err) = &result { + if Error::ErrTryAgain != *err { + break; + } + } + } + if result.is_err() { + log::warn!("refresh permissions failed"); + } + } + } + } +} + +fn socket_addr2peer_address(addr: &SocketAddr) -> proto::peeraddr::PeerAddress { + proto::peeraddr::PeerAddress { + ip: addr.ip(), + port: addr.port(), + } +} diff --git a/rtc-turn/src/client/relay_conn/relay_conn_test.rs b/rtc-turn/src/client/relay_conn/relay_conn_test.rs new file mode 100644 index 0000000..58f9802 --- /dev/null +++ b/rtc-turn/src/client/relay_conn/relay_conn_test.rs @@ -0,0 +1,84 @@ +use std::net::Ipv4Addr; + +use super::*; +use crate::error::Result; + +struct DummyRelayConnObserver { + turn_server_addr: String, + username: Username, + realm: Realm, +} + +#[async_trait] +impl RelayConnObserver for DummyRelayConnObserver { + fn turn_server_addr(&self) -> String { + self.turn_server_addr.clone() + } + + fn username(&self) -> Username { + self.username.clone() + } + + fn realm(&self) -> Realm { + self.realm.clone() + } + + async fn write_to(&self, _data: &[u8], _to: &str) -> std::result::Result { + Ok(0) + } + + async fn perform_transaction( + &mut self, + _msg: &Message, + _to: &str, + _dont_wait: bool, + ) -> Result { + Err(Error::ErrFakeErr) + } +} + +#[tokio::test] +async fn test_relay_conn() -> Result<()> { + let obs = DummyRelayConnObserver { + turn_server_addr: String::new(), + username: Username::new(ATTR_USERNAME, "username".to_owned()), + realm: Realm::new(ATTR_REALM, "realm".to_owned()), + }; + + let (_read_ch_tx, read_ch_rx) = mpsc::channel(100); + + let config = RelayConnConfig { + relayed_addr: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0), + integrity: MessageIntegrity::default(), + nonce: Nonce::new(ATTR_NONCE, "nonce".to_owned()), + lifetime: Duration::from_secs(0), + binding_mgr: Arc::new(Mutex::new(BindingManager::new())), + read_ch_rx: Arc::new(Mutex::new(read_ch_rx)), + }; + + let rc = RelayConn::new(Arc::new(Mutex::new(obs)), config).await; + + let rci = rc.relay_conn.lock().await; + let (bind_addr, bind_number) = { + let mut bm = rci.binding_mgr.lock().await; + let b = bm + .create(SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 1234)) + .unwrap(); + (b.addr, b.number) + }; + + //let binding_mgr = Arc::clone(&rci.binding_mgr); + let rc_obs = Arc::clone(&rci.obs); + let nonce = rci.nonce.clone(); + let integrity = rci.integrity.clone(); + + if let Err(err) = + RelayConnInternal::bind(rc_obs, bind_addr, bind_number, nonce, integrity).await + { + assert!(Error::ErrUnexpectedResponse != err); + } else { + panic!("should fail"); + } + + Ok(()) +} diff --git a/rtc-turn/src/client/transaction.rs b/rtc-turn/src/client/transaction.rs new file mode 100644 index 0000000..8a68d70 --- /dev/null +++ 
b/rtc-turn/src/client/transaction.rs @@ -0,0 +1,277 @@ +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::str::FromStr; +use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use stun::message::*; + +use shared::error::Error; + +const MAX_RTX_INTERVAL_IN_MS: u16 = 1600; +const MAX_RTX_COUNT: u16 = 7; // total 7 requests (Rc) + +async fn on_rtx_timeout( + conn: &Arc, + tr_map: &Arc>, + tr_key: &str, + n_rtx: u16, +) -> bool { + let mut tm = tr_map.lock().await; + let (tr_raw, tr_to) = match tm.find(tr_key) { + Some(tr) => (tr.raw.clone(), tr.to.clone()), + None => return true, // already gone + }; + + if n_rtx == MAX_RTX_COUNT { + // all retransmisstions failed + if let Some(tr) = tm.delete(tr_key) { + if !tr + .write_result(TransactionResult { + err: Some(Error::Other(format!( + "{:?} {}", + Error::ErrAllRetransmissionsFailed, + tr_key + ))), + ..Default::default() + }) + .await + { + log::debug!("no listener for transaction"); + } + } + return true; + } + + log::trace!( + "retransmitting transaction {} to {} (n_rtx={})", + tr_key, + tr_to, + n_rtx + ); + + let dst = match SocketAddr::from_str(&tr_to) { + Ok(dst) => dst, + Err(_) => return false, + }; + + if conn.send_to(&tr_raw, dst).await.is_err() { + if let Some(tr) = tm.delete(tr_key) { + if !tr + .write_result(TransactionResult { + err: Some(Error::Other(format!( + "{:?} {}", + Error::ErrAllRetransmissionsFailed, + tr_key + ))), + ..Default::default() + }) + .await + { + log::debug!("no listener for transaction"); + } + } + return true; + } + + false +} + +// TransactionResult is a bag of result values of a transaction +#[derive(Debug)] //Clone +pub struct TransactionResult { + pub msg: Message, + pub from: SocketAddr, + pub retries: u16, + pub err: Option, +} + +impl Default for TransactionResult { + fn default() -> Self { + TransactionResult { + msg: Message::default(), + from: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0), + retries: 0, + err: None, + } + } +} + +// TransactionConfig is a set of config params used by NewTransaction +#[derive(Default)] +pub struct TransactionConfig { + pub key: String, + pub raw: Vec, + pub to: String, + pub interval: u16, + pub ignore_result: bool, // true to throw away the result of this transaction (it will not be readable using wait_for_result) +} + +// Transaction represents a transaction +#[derive(Debug)] +pub struct Transaction { + pub key: String, + pub raw: Vec, + pub to: String, + pub n_rtx: Arc, + pub interval: Arc, + timer_ch_tx: Option>, + result_ch_tx: Option>, + result_ch_rx: Option>, +} + +impl Default for Transaction { + fn default() -> Self { + Transaction { + key: String::new(), + raw: vec![], + to: String::new(), + n_rtx: Arc::new(AtomicU16::new(0)), + interval: Arc::new(AtomicU16::new(0)), + //timer: None, + timer_ch_tx: None, + result_ch_tx: None, + result_ch_rx: None, + } + } +} + +impl Transaction { + // NewTransaction creates a new instance of Transaction + pub fn new(config: TransactionConfig) -> Self { + let (result_ch_tx, result_ch_rx) = if !config.ignore_result { + let (tx, rx) = mpsc::channel(1); + (Some(tx), Some(rx)) + } else { + (None, None) + }; + + Transaction { + key: config.key, + raw: config.raw, + to: config.to, + interval: Arc::new(AtomicU16::new(config.interval)), + result_ch_tx, + result_ch_rx, + ..Default::default() + } + } + + // start_rtx_timer starts the transaction timer + pub async fn start_rtx_timer( + &mut self, + conn: Arc, + tr_map: Arc>, + ) { + let 
(timer_ch_tx, mut timer_ch_rx) = mpsc::channel(1); + self.timer_ch_tx = Some(timer_ch_tx); + let (n_rtx, interval, key) = (self.n_rtx.clone(), self.interval.clone(), self.key.clone()); + + tokio::spawn(async move { + let mut done = false; + while !done { + let timer = tokio::time::sleep(Duration::from_millis( + interval.load(Ordering::SeqCst) as u64, + )); + tokio::pin!(timer); + + tokio::select! { + _ = timer.as_mut() => { + let rtx = n_rtx.fetch_add(1, Ordering::SeqCst); + + let mut val = interval.load(Ordering::SeqCst); + val *= 2; + if val > MAX_RTX_INTERVAL_IN_MS { + val = MAX_RTX_INTERVAL_IN_MS; + } + interval.store(val, Ordering::SeqCst); + + done = on_rtx_timeout(&conn, &tr_map, &key, rtx + 1).await; + } + _ = timer_ch_rx.recv() => done = true, + } + } + }); + } + + // stop_rtx_timer stop the transaction timer + pub fn stop_rtx_timer(&mut self) { + if self.timer_ch_tx.is_some() { + self.timer_ch_tx.take(); + } + } + + // write_result writes the result to the result channel + pub async fn write_result(&self, res: TransactionResult) -> bool { + if let Some(result_ch) = &self.result_ch_tx { + result_ch.send(res).await.is_ok() + } else { + false + } + } + + pub fn get_result_channel(&mut self) -> Option> { + self.result_ch_rx.take() + } + + // Close closes the transaction + pub fn close(&mut self) { + if self.result_ch_tx.is_some() { + self.result_ch_tx.take(); + } + } + + // retries returns the number of retransmission it has made + pub fn retries(&self) -> u16 { + self.n_rtx.load(Ordering::SeqCst) + } +} + +// TransactionMap is a thread-safe transaction map +#[derive(Default, Debug)] +pub struct TransactionMap { + tr_map: HashMap, +} + +impl TransactionMap { + // NewTransactionMap create a new instance of the transaction map + pub fn new() -> TransactionMap { + TransactionMap { + tr_map: HashMap::new(), + } + } + + // Insert inserts a trasaction to the map + pub fn insert(&mut self, key: String, tr: Transaction) -> bool { + self.tr_map.insert(key, tr); + true + } + + // Find looks up a transaction by its key + pub fn find(&self, key: &str) -> Option<&Transaction> { + self.tr_map.get(key) + } + + pub fn get(&mut self, key: &str) -> Option<&mut Transaction> { + self.tr_map.get_mut(key) + } + + // Delete deletes a transaction by its key + pub fn delete(&mut self, key: &str) -> Option { + self.tr_map.remove(key) + } + + // close_and_delete_all closes and deletes all transactions + pub fn close_and_delete_all(&mut self) { + for tr in self.tr_map.values_mut() { + tr.close(); + } + self.tr_map.clear(); + } + + // Size returns the length of the transaction map + pub fn size(&self) -> usize { + self.tr_map.len() + } +} diff --git a/reserved/rtc-turn/src/lib.rs b/rtc-turn/src/lib.rs similarity index 55% rename from reserved/rtc-turn/src/lib.rs rename to rtc-turn/src/lib.rs index 80c6a85..6e8cdfc 100644 --- a/reserved/rtc-turn/src/lib.rs +++ b/rtc-turn/src/lib.rs @@ -1,2 +1,5 @@ #![warn(rust_2018_idioms)] #![allow(dead_code)] + +//TODO:pub mod client; +pub mod proto; diff --git a/rtc-turn/src/proto/addr.rs b/rtc-turn/src/proto/addr.rs new file mode 100644 index 0000000..dc69dab --- /dev/null +++ b/rtc-turn/src/proto/addr.rs @@ -0,0 +1,62 @@ +#[cfg(test)] +mod addr_test; + +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +use super::*; + +// Addr is ip:port. 
+#[derive(PartialEq, Eq, Debug)] +pub struct Addr { + ip: IpAddr, + port: u16, +} + +impl Default for Addr { + fn default() -> Self { + Addr { + ip: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), + port: 0, + } + } +} + +impl fmt::Display for Addr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}:{}", self.ip, self.port) + } +} + +impl Addr { + // Network implements net.Addr. + pub fn network(&self) -> String { + "turn".to_owned() + } + + // sets addr. + pub fn from_socket_addr(n: &SocketAddr) -> Self { + let ip = n.ip(); + let port = n.port(); + + Addr { ip, port } + } + + // EqualIP returns true if a and b have equal IP addresses. + pub fn equal_ip(&self, other: &Addr) -> bool { + self.ip == other.ip + } +} + +// FiveTuple represents 5-TUPLE value. +#[derive(PartialEq, Eq, Default)] +pub struct FiveTuple { + pub client: Addr, + pub server: Addr, + pub proto: Protocol, +} + +impl fmt::Display for FiveTuple { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}->{} ({})", self.client, self.server, self.proto) + } +} diff --git a/rtc-turn/src/proto/addr/addr_test.rs b/rtc-turn/src/proto/addr/addr_test.rs new file mode 100644 index 0000000..70eba42 --- /dev/null +++ b/rtc-turn/src/proto/addr/addr_test.rs @@ -0,0 +1,104 @@ +use std::net::Ipv4Addr; + +use super::*; +use shared::error::Result; + +#[test] +fn test_addr_from_socket_addr() -> Result<()> { + let u = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); + + let a = Addr::from_socket_addr(&u); + assert!( + u.ip() == a.ip || u.port() != a.port || u.to_string() != a.to_string(), + "not equal" + ); + assert_eq!(a.network(), "turn", "unexpected network"); + + Ok(()) +} + +#[test] +fn test_addr_equal_ip() -> Result<()> { + let a = Addr { + ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + port: 1337, + }; + let b = Addr { + ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + port: 1338, + }; + assert_ne!(a, b, "a != b"); + assert!(a.equal_ip(&b), "a.IP should equal to b.IP"); + + Ok(()) +} + +#[test] +fn test_five_tuple_equal() -> Result<()> { + let tests = vec![ + ("blank", FiveTuple::default(), FiveTuple::default(), true), + ( + "proto", + FiveTuple { + proto: PROTO_UDP, + ..Default::default() + }, + FiveTuple::default(), + false, + ), + ( + "server", + FiveTuple { + server: Addr { + port: 100, + ..Default::default() + }, + ..Default::default() + }, + FiveTuple::default(), + false, + ), + ( + "client", + FiveTuple { + client: Addr { + port: 100, + ..Default::default() + }, + ..Default::default() + }, + FiveTuple::default(), + false, + ), + ]; + + for (name, a, b, r) in tests { + let v = a == b; + assert_eq!(v, r, "({name}) {a} [{v}!={r}] {b}"); + } + + Ok(()) +} + +#[test] +fn test_five_tuple_string() -> Result<()> { + let s = FiveTuple { + proto: PROTO_UDP, + server: Addr { + port: 100, + ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + }, + client: Addr { + port: 200, + ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + }, + } + .to_string(); + + assert_eq!( + s, "127.0.0.1:200->127.0.0.1:100 (UDP)", + "unexpected stringer output" + ); + + Ok(()) +} diff --git a/rtc-turn/src/proto/chandata.rs b/rtc-turn/src/proto/chandata.rs new file mode 100644 index 0000000..337f37b --- /dev/null +++ b/rtc-turn/src/proto/chandata.rs @@ -0,0 +1,111 @@ +#[cfg(test)] +mod chandata_test; + +use super::channum::*; +use shared::error::{Error, Result}; + +const PADDING: usize = 4; + +fn nearest_padded_value_length(l: usize) -> usize { + let mut n = PADDING * (l / PADDING); + if n < l { + n += PADDING; + } + n +} + 
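// --- Editorial sketch, not part of this patch --------------------------------
// `nearest_padded_value_length` above rounds a ChannelData length up to the
// next 4-byte boundary, the padding used for ChannelData messages on stream
// transports (RFC 5766 Section 11.5). A hypothetical test illustrating the
// rounding; the module name and placement are illustrative only:
#[cfg(test)]
mod padding_sketch {
    use super::nearest_padded_value_length;

    #[test]
    fn rounds_up_to_multiple_of_four() {
        // Lengths already on a 4-byte boundary are left unchanged.
        assert_eq!(nearest_padded_value_length(0), 0);
        assert_eq!(nearest_padded_value_length(4), 4);
        // Everything else is rounded up to the next multiple of four.
        assert_eq!(nearest_padded_value_length(1), 4);
        assert_eq!(nearest_padded_value_length(5), 8);
        assert_eq!(nearest_padded_value_length(7), 8);
    }
}
// ------------------------------------------------------------------------------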
+const CHANNEL_DATA_LENGTH_SIZE: usize = 2; +const CHANNEL_DATA_NUMBER_SIZE: usize = CHANNEL_DATA_LENGTH_SIZE; +const CHANNEL_DATA_HEADER_SIZE: usize = CHANNEL_DATA_LENGTH_SIZE + CHANNEL_DATA_NUMBER_SIZE; + +// ChannelData represents The ChannelData Message. +// +// See RFC 5766 Section 11.4 +#[derive(Default, Debug)] +pub struct ChannelData { + pub data: Vec, // can be subslice of Raw + pub number: ChannelNumber, + pub raw: Vec, +} + +impl PartialEq for ChannelData { + fn eq(&self, other: &Self) -> bool { + self.data == other.data && self.number == other.number + } +} + +impl ChannelData { + // Reset resets Length, Data and Raw length. + #[inline] + pub fn reset(&mut self) { + self.raw.clear(); + self.data.clear(); + } + + // Encode encodes ChannelData Message to Raw. + pub fn encode(&mut self) { + self.raw.clear(); + self.write_header(); + self.raw.extend_from_slice(&self.data); + let padded = nearest_padded_value_length(self.raw.len()); + let bytes_to_add = padded - self.raw.len(); + if bytes_to_add > 0 { + self.raw.extend_from_slice(&vec![0; bytes_to_add]); + } + } + + // Decode decodes The ChannelData Message from Raw. + pub fn decode(&mut self) -> Result<()> { + let buf = &self.raw; + if buf.len() < CHANNEL_DATA_HEADER_SIZE { + return Err(Error::ErrUnexpectedEof); + } + let num = u16::from_be_bytes([buf[0], buf[1]]); + self.number = ChannelNumber(num); + if !self.number.valid() { + return Err(Error::ErrInvalidChannelNumber); + } + let l = u16::from_be_bytes([ + buf[CHANNEL_DATA_NUMBER_SIZE], + buf[CHANNEL_DATA_NUMBER_SIZE + 1], + ]) as usize; + if l > buf[CHANNEL_DATA_HEADER_SIZE..].len() { + return Err(Error::ErrBadChannelDataLength); + } + self.data = buf[CHANNEL_DATA_HEADER_SIZE..CHANNEL_DATA_HEADER_SIZE + l].to_vec(); + + Ok(()) + } + + // WriteHeader writes channel number and length. + pub fn write_header(&mut self) { + if self.raw.len() < CHANNEL_DATA_HEADER_SIZE { + // Making WriteHeader call valid even when c.Raw + // is nil or len(c.Raw) is less than needed for header. + self.raw + .resize(self.raw.len() + CHANNEL_DATA_HEADER_SIZE, 0); + } + self.raw[..CHANNEL_DATA_NUMBER_SIZE].copy_from_slice(&self.number.0.to_be_bytes()); + self.raw[CHANNEL_DATA_NUMBER_SIZE..CHANNEL_DATA_HEADER_SIZE] + .copy_from_slice(&(self.data.len() as u16).to_be_bytes()); + } + + // is_channel_data returns true if buf looks like the ChannelData Message. + pub fn is_channel_data(buf: &[u8]) -> bool { + if buf.len() < CHANNEL_DATA_HEADER_SIZE { + return false; + } + + if u16::from_be_bytes([ + buf[CHANNEL_DATA_NUMBER_SIZE], + buf[CHANNEL_DATA_NUMBER_SIZE + 1], + ]) > buf[CHANNEL_DATA_HEADER_SIZE..].len() as u16 + { + return false; + } + + // Quick check for channel number. 
+ let num = ChannelNumber(u16::from_be_bytes([buf[0], buf[1]])); + num.valid() + } +} diff --git a/rtc-turn/src/proto/chandata/chandata_test.rs b/rtc-turn/src/proto/chandata/chandata_test.rs new file mode 100644 index 0000000..9376f6f --- /dev/null +++ b/rtc-turn/src/proto/chandata/chandata_test.rs @@ -0,0 +1,211 @@ +use super::*; + +#[test] +fn test_channel_data_encode() -> Result<()> { + let mut d = ChannelData { + data: vec![1, 2, 3, 4], + number: ChannelNumber(MIN_CHANNEL_NUMBER + 1), + ..Default::default() + }; + d.encode(); + + let mut b = ChannelData::default(); + b.raw.extend_from_slice(&d.raw); + b.decode()?; + + assert_eq!(b, d, "not equal"); + + assert!( + ChannelData::is_channel_data(&b.raw) && ChannelData::is_channel_data(&d.raw), + "unexpected IsChannelData" + ); + + Ok(()) +} + +#[test] +fn test_channel_data_equal() -> Result<()> { + let tests = vec![ + ( + "equal", + ChannelData { + number: ChannelNumber(MIN_CHANNEL_NUMBER), + data: vec![1, 2, 3], + ..Default::default() + }, + ChannelData { + number: ChannelNumber(MIN_CHANNEL_NUMBER), + data: vec![1, 2, 3], + ..Default::default() + }, + true, + ), + ( + "number", + ChannelData { + number: ChannelNumber(MIN_CHANNEL_NUMBER + 1), + data: vec![1, 2, 3], + ..Default::default() + }, + ChannelData { + number: ChannelNumber(MIN_CHANNEL_NUMBER), + data: vec![1, 2, 3], + ..Default::default() + }, + false, + ), + ( + "length", + ChannelData { + number: ChannelNumber(MIN_CHANNEL_NUMBER), + data: vec![1, 2, 3, 4], + ..Default::default() + }, + ChannelData { + number: ChannelNumber(MIN_CHANNEL_NUMBER), + data: vec![1, 2, 3], + ..Default::default() + }, + false, + ), + ( + "data", + ChannelData { + number: ChannelNumber(MIN_CHANNEL_NUMBER), + data: vec![1, 2, 2], + ..Default::default() + }, + ChannelData { + number: ChannelNumber(MIN_CHANNEL_NUMBER), + data: vec![1, 2, 3], + ..Default::default() + }, + false, + ), + ]; + + for (name, a, b, r) in tests { + let v = a == b; + assert_eq!(v, r, "unexpected: ({name}) {r} != {r}"); + } + + Ok(()) +} + +#[test] +fn test_channel_data_decode() -> Result<()> { + let tests = vec![ + ("small", vec![1, 2, 3], Error::ErrUnexpectedEof), + ( + "zeroes", + vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + Error::ErrInvalidChannelNumber, + ), + ( + "bad chan number", + vec![63, 255, 0, 0, 0, 4, 0, 0, 1, 2, 3, 4], + Error::ErrInvalidChannelNumber, + ), + ( + "bad length", + vec![0x40, 0x40, 0x02, 0x23, 0x16, 0, 0, 0, 0, 0, 0, 0], + Error::ErrBadChannelDataLength, + ), + ]; + + for (name, buf, want_err) in tests { + let mut m = ChannelData { + raw: buf, + ..Default::default() + }; + if let Err(err) = m.decode() { + assert_eq!(want_err, err, "unexpected: ({name}) {want_err} != {err}"); + } else { + panic!("expected error, but got ok"); + } + } + + Ok(()) +} + +#[test] +fn test_channel_data_reset() -> Result<()> { + let mut d = ChannelData { + data: vec![1, 2, 3, 4], + number: ChannelNumber(MIN_CHANNEL_NUMBER + 1), + ..Default::default() + }; + d.encode(); + let mut buf = vec![0; d.raw.len()]; + buf.copy_from_slice(&d.raw); + d.reset(); + d.raw = buf; + d.decode()?; + + Ok(()) +} + +#[test] +fn test_is_channel_data() -> Result<()> { + let tests = vec![ + ("small", vec![1, 2, 3, 4], false), + ("zeroes", vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], false), + ]; + + for (name, buf, r) in tests { + let v = ChannelData::is_channel_data(&buf); + assert_eq!(v, r, "unexpected: ({name}) {r} != {v}"); + } + + Ok(()) +} + +const CHANDATA_TEST_HEX: [&str; 2] = [ + 
"40000064000100502112a442453731722f2b322b6e4e7a5800060009443758343a33776c59000000c0570004000003e7802a00081d5136dab65b169300250000002400046e001eff0008001465d11a330e104a9f5f598af4abc6a805f26003cf802800046b334442", + "4000022316fefd0000000000000011012c0b000120000100000000012000011d00011a308201163081bda003020102020900afe52871340bd13e300a06082a8648ce3d0403023011310f300d06035504030c06576562525443301e170d3138303831313033353230305a170d3138303931313033353230305a3011310f300d06035504030c065765625254433059301306072a8648ce3d020106082a8648ce3d030107034200048080e348bd41469cfb7a7df316676fd72a06211765a50a0f0b07526c872dcf80093ed5caa3f5a40a725dd74b41b79bdd19ee630c5313c8601d6983286c8722c1300a06082a8648ce3d0403020348003045022100d13a0a131bc2a9f27abd3d4c547f7ef172996a0c0755c707b6a3e048d8762ded0220055fc8182818a644a3d3b5b157304cc3f1421fadb06263bfb451cd28be4bc9ee16fefd0000000000000012002d10000021000200000000002120f7e23c97df45a96e13cb3e76b37eff5e73e2aee0b6415d29443d0bd24f578b7e16fefd000000000000001300580f00004c000300000000004c040300483046022100fdbb74eab1aca1532e6ac0ab267d5b83a24bb4d5d7d504936e2785e6e388b2bd022100f6a457b9edd9ead52a9d0e9a19240b3a68b95699546c044f863cf8349bc8046214fefd000000000000001400010116fefd0001000000000004003000010000000000040aae2421e7d549632a7def8ed06898c3c5b53f5b812a963a39ab6cdd303b79bdb237f3314c1da21b", +]; + +#[test] +fn test_chrome_channel_data() -> Result<()> { + let mut data = vec![]; + let mut messages = vec![]; + + // Decoding hex data into binary. + for h in &CHANDATA_TEST_HEX { + let b = match hex::decode(h) { + Ok(b) => b, + Err(_) => return Err(Error::Other("hex decode error".to_owned())), + }; + data.push(b); + } + + // All hex streams decoded to raw binary format and stored in data slice. + // Decoding packets to messages. + for packet in data { + let mut m = ChannelData { + raw: packet, + ..Default::default() + }; + + m.decode()?; + let mut encoded = ChannelData { + data: m.data.clone(), + number: m.number, + ..Default::default() + }; + encoded.encode(); + + let mut decoded = ChannelData { + raw: encoded.raw.clone(), + ..Default::default() + }; + + decoded.decode()?; + assert_eq!(decoded, m, "should be equal"); + + messages.push(m); + } + assert_eq!(messages.len(), 2, "unexpected message slice list"); + + Ok(()) +} diff --git a/rtc-turn/src/proto/channum.rs b/rtc-turn/src/proto/channum.rs new file mode 100644 index 0000000..d00d17a --- /dev/null +++ b/rtc-turn/src/proto/channum.rs @@ -0,0 +1,72 @@ +#[cfg(test)] +mod channnum_test; + +use std::fmt; + +use stun::attributes::*; +use stun::checks::*; +use stun::message::*; + +use shared::error::Result; + +// 16 bits of uint + 16 bits of RFFU = 0. +const CHANNEL_NUMBER_SIZE: usize = 4; + +// See https://tools.ietf.org/html/rfc5766#section-11: +// +// 0x4000 through 0x7FFF: These values are the allowed channel +// numbers (16,383 possible values). +pub const MIN_CHANNEL_NUMBER: u16 = 0x4000; +pub const MAX_CHANNEL_NUMBER: u16 = 0x7FFF; + +// ChannelNumber represents CHANNEL-NUMBER attribute. +// +// The CHANNEL-NUMBER attribute contains the number of the channel. +// +// RFC 5766 Section 14.1 +// encoded as uint16 +#[derive(Default, Eq, PartialEq, Debug, Copy, Clone, Hash)] +pub struct ChannelNumber(pub u16); + +impl fmt::Display for ChannelNumber { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl Setter for ChannelNumber { + // AddTo adds CHANNEL-NUMBER to message. 
+ fn add_to(&self, m: &mut Message) -> Result<()> { + let mut v = vec![0; CHANNEL_NUMBER_SIZE]; + v[..2].copy_from_slice(&self.0.to_be_bytes()); + // v[2:4] are zeroes (RFFU = 0) + m.add(ATTR_CHANNEL_NUMBER, &v); + Ok(()) + } +} + +impl Getter for ChannelNumber { + // GetFrom decodes CHANNEL-NUMBER from message. + fn get_from(&mut self, m: &Message) -> Result<()> { + let v = m.get(ATTR_CHANNEL_NUMBER)?; + + check_size(ATTR_CHANNEL_NUMBER, v.len(), CHANNEL_NUMBER_SIZE)?; + + //_ = v[CHANNEL_NUMBER_SIZE-1] // asserting length + self.0 = u16::from_be_bytes([v[0], v[1]]); + // v[2:4] is RFFU and equals to 0. + Ok(()) + } +} + +impl ChannelNumber { + // is_channel_number_valid returns true if c in [0x4000, 0x7FFF]. + fn is_channel_number_valid(&self) -> bool { + self.0 >= MIN_CHANNEL_NUMBER && self.0 <= MAX_CHANNEL_NUMBER + } + + // Valid returns true if channel number has correct value that complies RFC 5766 Section 11 range. + pub fn valid(&self) -> bool { + self.is_channel_number_valid() + } +} diff --git a/rtc-turn/src/proto/channum/channnum_test.rs b/rtc-turn/src/proto/channum/channnum_test.rs new file mode 100644 index 0000000..70a5a71 --- /dev/null +++ b/rtc-turn/src/proto/channum/channnum_test.rs @@ -0,0 +1,82 @@ +use super::*; +use shared::error::Error; + +#[test] +fn test_channel_number_string() -> Result<()> { + let n = ChannelNumber(112); + assert_eq!(n.to_string(), "112", "bad string {n}, expected 112"); + Ok(()) +} + +/* +#[test] +fn test_channel_number_NoAlloc() -> Result<()> { + let mut m = Message::default(); + + if wasAllocs(func() { + // Case with ChannelNumber on stack. + n: = ChannelNumber(6) + n.AddTo(m) //nolint + m.Reset() + }) { + t.Error("Unexpected allocations") + } + + n: = ChannelNumber(12) + nP: = &n + if wasAllocs(func() { + // On heap. + nP.AddTo(m) //nolint + m.Reset() + }) { + t.Error("Unexpected allocations") + } + Ok(()) +} +*/ + +#[test] +fn test_channel_number_add_to() -> Result<()> { + let mut m = Message::new(); + let n = ChannelNumber(6); + n.add_to(&mut m)?; + m.write_header(); + + //"GetFrom" + { + let mut decoded = Message::new(); + decoded.write(&m.raw)?; + + let mut num_decoded = ChannelNumber::default(); + num_decoded.get_from(&decoded)?; + assert_eq!(num_decoded, n, "Decoded {num_decoded}, expected {n}"); + + //"HandleErr" + { + let mut m = Message::new(); + let mut n_handle = ChannelNumber::default(); + if let Err(err) = n_handle.get_from(&m) { + assert_eq!( + Error::ErrAttributeNotFound, + err, + "{err} should be not found" + ); + } else { + panic!("expected error, but got ok"); + } + + m.add(ATTR_CHANNEL_NUMBER, &[1, 2, 3]); + + if let Err(err) = n_handle.get_from(&m) { + assert!( + is_attr_size_invalid(&err), + "IsAttrSizeInvalid should be true" + ); + } else { + panic!("expected error, but got ok"); + } + } + } + + Ok(()) +} diff --git a/rtc-turn/src/proto/data.rs b/rtc-turn/src/proto/data.rs new file mode 100644 index 0000000..f7e8961 --- /dev/null +++ b/rtc-turn/src/proto/data.rs @@ -0,0 +1,35 @@ +#[cfg(test)] +mod data_test; + +use stun::attributes::*; +use stun::message::*; + +use shared::error::Result; + +// Data represents DATA attribute. +// +// The DATA attribute is present in all Send and Data indications. The +// value portion of this attribute is variable length and consists of +// the application data (that is, the data that would immediately follow +// the UDP header if the data was been sent directly between the client +// and the peer). 
+// +// RFC 5766 Section 14.4 +#[derive(Default, Debug, PartialEq, Eq)] +pub struct Data(pub Vec); + +impl Setter for Data { + // AddTo adds DATA to message. + fn add_to(&self, m: &mut Message) -> Result<()> { + m.add(ATTR_DATA, &self.0); + Ok(()) + } +} + +impl Getter for Data { + // GetFrom decodes DATA from message. + fn get_from(&mut self, m: &Message) -> Result<()> { + self.0 = m.get(ATTR_DATA)?; + Ok(()) + } +} diff --git a/rtc-turn/src/proto/data/data_test.rs b/rtc-turn/src/proto/data/data_test.rs new file mode 100644 index 0000000..f710c54 --- /dev/null +++ b/rtc-turn/src/proto/data/data_test.rs @@ -0,0 +1,34 @@ +use super::*; +use shared::error::Error; + +#[test] +fn test_data_add_to() -> Result<()> { + let mut m = Message::new(); + let d = Data(vec![1, 2, 33, 44, 0x13, 0xaf]); + d.add_to(&mut m)?; + m.write_header(); + + //"GetFrom" + { + let mut decoded = Message::new(); + decoded.write(&m.raw)?; + + let mut data_decoded = Data::default(); + data_decoded.get_from(&decoded)?; + assert_eq!(data_decoded, d); + + //"HandleErr" + { + let m = Message::new(); + let mut handle = Data::default(); + if let Err(err) = handle.get_from(&m) { + assert_eq!( + Error::ErrAttributeNotFound, + err, + "{err} should be not found" + ); + } + } + } + Ok(()) +} diff --git a/rtc-turn/src/proto/dontfrag.rs b/rtc-turn/src/proto/dontfrag.rs new file mode 100644 index 0000000..cfe936a --- /dev/null +++ b/rtc-turn/src/proto/dontfrag.rs @@ -0,0 +1,26 @@ +#[cfg(test)] +mod dontfrag_test; + +use shared::error::Result; +use stun::attributes::*; +use stun::message::*; + +// DontFragmentAttr represents DONT-FRAGMENT attribute. +#[derive(Debug, Default, PartialEq, Eq)] +pub struct DontFragmentAttr; + +impl Setter for DontFragmentAttr { + // AddTo adds DONT-FRAGMENT attribute to message. + fn add_to(&self, m: &mut Message) -> Result<()> { + m.add(ATTR_DONT_FRAGMENT, &[]); + Ok(()) + } +} + +impl Getter for DontFragmentAttr { + // get_from returns true if DONT-FRAGMENT attribute is set. + fn get_from(&mut self, m: &Message) -> Result<()> { + let _ = m.get(ATTR_DONT_FRAGMENT)?; + Ok(()) + } +} diff --git a/rtc-turn/src/proto/dontfrag/dontfrag_test.rs b/rtc-turn/src/proto/dontfrag/dontfrag_test.rs new file mode 100644 index 0000000..3e45f70 --- /dev/null +++ b/rtc-turn/src/proto/dontfrag/dontfrag_test.rs @@ -0,0 +1,27 @@ +use super::*; + +#[test] +fn test_dont_fragment_false() -> Result<()> { + let mut dont_fragment = DontFragmentAttr::default(); + + let mut m = Message::new(); + m.write_header(); + assert!(dont_fragment.get_from(&m).is_err(), "should not be set"); + + Ok(()) +} + +#[test] +fn test_dont_fragment_add_to() -> Result<()> { + let mut dont_fragment = DontFragmentAttr::default(); + + let mut m = Message::new(); + dont_fragment.add_to(&mut m)?; + m.write_header(); + + let mut decoded = Message::new(); + decoded.write(&m.raw)?; + assert!(dont_fragment.get_from(&m).is_ok(), "should be set"); + + Ok(()) +} diff --git a/rtc-turn/src/proto/evenport.rs b/rtc-turn/src/proto/evenport.rs new file mode 100644 index 0000000..3cea809 --- /dev/null +++ b/rtc-turn/src/proto/evenport.rs @@ -0,0 +1,64 @@ +#[cfg(test)] +mod evenport_test; + +use std::fmt; + +use shared::error::Result; +use stun::attributes::*; +use stun::checks::*; +use stun::message::*; + +// EvenPort represents EVEN-PORT attribute. +// +// This attribute allows the client to request that the port in the +// relayed transport address be even, and (optionally) that the server +// reserve the next-higher port number. 
+// +// RFC 5766 Section 14.6 +#[derive(Default, Debug, PartialEq, Eq)] +pub struct EvenPort { + // reserve_port means that the server is requested to reserve + // the next-higher port number (on the same IP address) + // for a subsequent allocation. + reserve_port: bool, +} + +impl fmt::Display for EvenPort { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.reserve_port { + write!(f, "reserve: true") + } else { + write!(f, "reserve: false") + } + } +} + +const EVEN_PORT_SIZE: usize = 1; +const FIRST_BIT_SET: u8 = 0b10000000; //FIXME? (1 << 8) - 1; + +impl Setter for EvenPort { + // AddTo adds EVEN-PORT to message. + fn add_to(&self, m: &mut Message) -> Result<()> { + let mut v = vec![0; EVEN_PORT_SIZE]; + if self.reserve_port { + // Set first bit to 1. + v[0] = FIRST_BIT_SET; + } + m.add(ATTR_EVEN_PORT, &v); + Ok(()) + } +} + +impl Getter for EvenPort { + // GetFrom decodes EVEN-PORT from message. + fn get_from(&mut self, m: &Message) -> Result<()> { + let v = m.get(ATTR_EVEN_PORT)?; + + check_size(ATTR_EVEN_PORT, v.len(), EVEN_PORT_SIZE)?; + + if v[0] & FIRST_BIT_SET > 0 { + self.reserve_port = true; + } + Ok(()) + } +} diff --git a/rtc-turn/src/proto/evenport/evenport_test.rs b/rtc-turn/src/proto/evenport/evenport_test.rs new file mode 100644 index 0000000..183de40 --- /dev/null +++ b/rtc-turn/src/proto/evenport/evenport_test.rs @@ -0,0 +1,79 @@ +use super::*; +use shared::error::Error; + +#[test] +fn test_even_port_string() -> Result<()> { + let mut p = EvenPort::default(); + assert_eq!( + p.to_string(), + "reserve: false", + "bad value {p} for reselve: false" + ); + + p.reserve_port = true; + assert_eq!( + p.to_string(), + "reserve: true", + "bad value {p} for reselve: true" + ); + + Ok(()) +} + +#[test] +fn test_even_port_false() -> Result<()> { + let mut m = Message::new(); + let p = EvenPort { + reserve_port: false, + }; + p.add_to(&mut m)?; + m.write_header(); + + let mut decoded = Message::new(); + let mut port = EvenPort::default(); + decoded.write(&m.raw)?; + port.get_from(&m)?; + assert_eq!(port, p); + + Ok(()) +} + +#[test] +fn test_even_port_add_to() -> Result<()> { + let mut m = Message::new(); + let p = EvenPort { reserve_port: true }; + p.add_to(&mut m)?; + m.write_header(); + //"GetFrom" + { + let mut decoded = Message::new(); + decoded.write(&m.raw)?; + let mut port = EvenPort::default(); + port.get_from(&decoded)?; + assert_eq!(port, p, "Decoded {port}, expected {p}"); + + //"HandleErr" + { + let mut m = Message::new(); + let mut handle = EvenPort::default(); + if let Err(err) = handle.get_from(&m) { + assert_eq!( + Error::ErrAttributeNotFound, + err, + "{err} should be not found" + ); + } + m.add(ATTR_EVEN_PORT, &[1, 2, 3]); + if let Err(err) = handle.get_from(&m) { + assert!( + is_attr_size_invalid(&err), + "IsAttrSizeInvalid should be true" + ); + } else { + panic!("expected error, but got ok"); + } + } + } + + Ok(()) +} diff --git a/rtc-turn/src/proto/lifetime.rs b/rtc-turn/src/proto/lifetime.rs new file mode 100644 index 0000000..363fb50 --- /dev/null +++ b/rtc-turn/src/proto/lifetime.rs @@ -0,0 +1,60 @@ +#[cfg(test)] +mod lifetime_test; + +use std::fmt; +use std::time::Duration; + +use shared::error::Result; +use stun::attributes::*; +use stun::checks::*; +use stun::message::*; + +// DEFAULT_LIFETIME in RFC 5766 is 10 minutes. +// +// RFC 5766 Section 2.2 +pub const DEFAULT_LIFETIME: Duration = Duration::from_secs(10 * 60); + +// Lifetime represents LIFETIME attribute. 
+// +// The LIFETIME attribute represents the duration for which the server +// will maintain an allocation in the absence of a refresh. The value +// portion of this attribute is 4-bytes long and consists of a 32-bit +// unsigned integral value representing the number of seconds remaining +// until expiration. +// +// RFC 5766 Section 14.2 +#[derive(Default, Debug, PartialEq, Eq)] +pub struct Lifetime(pub Duration); + +impl fmt::Display for Lifetime { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}s", self.0.as_secs()) + } +} + +// uint32 seconds +const LIFETIME_SIZE: usize = 4; // 4 bytes, 32 bits + +impl Setter for Lifetime { + // AddTo adds LIFETIME to message. + fn add_to(&self, m: &mut Message) -> Result<()> { + let mut v = vec![0; LIFETIME_SIZE]; + v.copy_from_slice(&(self.0.as_secs() as u32).to_be_bytes()); + m.add(ATTR_LIFETIME, &v); + Ok(()) + } +} + +impl Getter for Lifetime { + // GetFrom decodes LIFETIME from message. + fn get_from(&mut self, m: &Message) -> Result<()> { + let v = m.get(ATTR_LIFETIME)?; + + check_size(ATTR_LIFETIME, v.len(), LIFETIME_SIZE)?; + + let seconds = u32::from_be_bytes([v[0], v[1], v[2], v[3]]); + self.0 = Duration::from_secs(seconds as u64); + + Ok(()) + } +} diff --git a/rtc-turn/src/proto/lifetime/lifetime_test.rs b/rtc-turn/src/proto/lifetime/lifetime_test.rs new file mode 100644 index 0000000..9c1f9b8 --- /dev/null +++ b/rtc-turn/src/proto/lifetime/lifetime_test.rs @@ -0,0 +1,55 @@ +use super::*; +use shared::error::Error; + +#[test] +fn test_lifetime_string() -> Result<()> { + let l = Lifetime(Duration::from_secs(10)); + assert_eq!(l.to_string(), "10s", "bad string {l}, expected 10s"); + + Ok(()) +} + +#[test] +fn test_lifetime_add_to() -> Result<()> { + let mut m = Message::new(); + let l = Lifetime(Duration::from_secs(10)); + l.add_to(&mut m)?; + m.write_header(); + + //"GetFrom" + { + let mut decoded = Message::new(); + decoded.write(&m.raw)?; + + let mut life = Lifetime::default(); + life.get_from(&decoded)?; + assert_eq!(life, l, "Decoded {life}, expected {l}"); + + //"HandleErr" + { + let mut m = Message::new(); + let mut n_handle = Lifetime::default(); + if let Err(err) = n_handle.get_from(&m) { + assert_eq!( + Error::ErrAttributeNotFound, + err, + "{err} should be not found" + ); + } else { + panic!("expected error, but got ok"); + } + m.add(ATTR_LIFETIME, &[1, 2, 3]); + + if let Err(err) = n_handle.get_from(&m) { + assert!( + is_attr_size_invalid(&err), + "IsAttrSizeInvalid should be true" + ); + } else { + panic!("expected error, but got ok"); + } + } + } + + Ok(()) +} diff --git a/rtc-turn/src/proto/mod.rs b/rtc-turn/src/proto/mod.rs new file mode 100644 index 0000000..33e2b8f --- /dev/null +++ b/rtc-turn/src/proto/mod.rs @@ -0,0 +1,69 @@ +#[cfg(test)] +mod proto_test; + +pub mod addr; +pub mod chandata; +pub mod channum; +pub mod data; +pub mod dontfrag; +pub mod evenport; +pub mod lifetime; +pub mod peeraddr; +pub mod relayaddr; +pub mod reqfamily; +pub mod reqtrans; +pub mod rsrvtoken; + +use std::fmt; + +use stun::message::*; + +// proto implements RFC 5766 Traversal Using Relays around NAT. + +// protocol is IANA assigned protocol number. +#[derive(PartialEq, Eq, Default, Debug, Clone, Copy, Hash)] +pub struct Protocol(pub u8); + +// PROTO_UDP is IANA assigned protocol number for UDP. 
+pub const PROTO_TCP: Protocol = Protocol(6); +pub const PROTO_UDP: Protocol = Protocol(17); + +impl fmt::Display for Protocol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let others = format!("{}", self.0); + let s = match *self { + PROTO_UDP => "UDP", + PROTO_TCP => "TCP", + _ => others.as_str(), + }; + + write!(f, "{s}") + } +} + +// Default ports for TURN from RFC 5766 Section 4. + +// DEFAULT_PORT for TURN is same as STUN. +pub const DEFAULT_PORT: u16 = stun::DEFAULT_PORT; +// DEFAULT_TLSPORT is for TURN over TLS and is same as STUN. +pub const DEFAULT_TLS_PORT: u16 = stun::DEFAULT_TLS_PORT; + +// create_permission_request is shorthand for create permission request type. +pub fn create_permission_request() -> MessageType { + MessageType::new(METHOD_CREATE_PERMISSION, CLASS_REQUEST) +} + +// allocate_request is shorthand for allocation request message type. +pub fn allocate_request() -> MessageType { + MessageType::new(METHOD_ALLOCATE, CLASS_REQUEST) +} + +// send_indication is shorthand for send indication message type. +pub fn send_indication() -> MessageType { + MessageType::new(METHOD_SEND, CLASS_INDICATION) +} + +// refresh_request is shorthand for refresh request message type. +pub fn refresh_request() -> MessageType { + MessageType::new(METHOD_REFRESH, CLASS_REQUEST) +} diff --git a/rtc-turn/src/proto/peeraddr.rs b/rtc-turn/src/proto/peeraddr.rs new file mode 100644 index 0000000..8cb35ab --- /dev/null +++ b/rtc-turn/src/proto/peeraddr.rs @@ -0,0 +1,72 @@ +#[cfg(test)] +mod peeraddr_test; + +use std::fmt; +use std::net::{IpAddr, Ipv4Addr}; + +use shared::error::Result; +use stun::attributes::*; +use stun::message::*; +use stun::xoraddr::*; + +// PeerAddress implements XOR-PEER-ADDRESS attribute. +// +// The XOR-PEER-ADDRESS specifies the address and port of the peer as +// seen from the TURN server. (For example, the peer's server-reflexive +// transport address if the peer is behind a NAT.) +// +// RFC 5766 Section 14.3 +#[derive(PartialEq, Eq, Debug)] +pub struct PeerAddress { + pub ip: IpAddr, + pub port: u16, +} + +impl Default for PeerAddress { + fn default() -> Self { + PeerAddress { + ip: IpAddr::V4(Ipv4Addr::from(0)), + port: 0, + } + } +} + +impl fmt::Display for PeerAddress { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.ip { + IpAddr::V4(_) => write!(f, "{}:{}", self.ip, self.port), + IpAddr::V6(_) => write!(f, "[{}]:{}", self.ip, self.port), + } + } +} + +impl Setter for PeerAddress { + // AddTo adds XOR-PEER-ADDRESS to message. + fn add_to(&self, m: &mut Message) -> Result<()> { + let a = XorMappedAddress { + ip: self.ip, + port: self.port, + }; + a.add_to_as(m, ATTR_XOR_PEER_ADDRESS) + } +} + +impl Getter for PeerAddress { + // GetFrom decodes XOR-PEER-ADDRESS from message. + fn get_from(&mut self, m: &Message) -> Result<()> { + let mut a = XorMappedAddress::default(); + a.get_from_as(m, ATTR_XOR_PEER_ADDRESS)?; + self.ip = a.ip; + self.port = a.port; + Ok(()) + } +} + +// XORPeerAddress implements XOR-PEER-ADDRESS attribute. +// +// The XOR-PEER-ADDRESS specifies the address and port of the peer as +// seen from the TURN server. (For example, the peer's server-reflexive +// transport address if the peer is behind a NAT.) 
+// +// RFC 5766 Section 14.3 +pub type XorPeerAddress = PeerAddress; diff --git a/rtc-turn/src/proto/peeraddr/peeraddr_test.rs b/rtc-turn/src/proto/peeraddr/peeraddr_test.rs new file mode 100644 index 0000000..e9b08d8 --- /dev/null +++ b/rtc-turn/src/proto/peeraddr/peeraddr_test.rs @@ -0,0 +1,26 @@ +use std::net::Ipv4Addr; + +use super::*; + +#[test] +fn test_peer_address() -> Result<()> { + // Simple tests because already tested in stun. + let a = PeerAddress { + ip: IpAddr::V4(Ipv4Addr::new(111, 11, 1, 2)), + port: 333, + }; + + assert_eq!(a.to_string(), "111.11.1.2:333", "invalid string"); + + let mut m = Message::new(); + a.add_to(&mut m)?; + m.write_header(); + + let mut decoded = Message::new(); + decoded.write(&m.raw)?; + + let mut a_got = PeerAddress::default(); + a_got.get_from(&decoded)?; + + Ok(()) +} diff --git a/rtc-turn/src/proto/proto_test.rs b/rtc-turn/src/proto/proto_test.rs new file mode 100644 index 0000000..dae4857 --- /dev/null +++ b/rtc-turn/src/proto/proto_test.rs @@ -0,0 +1,35 @@ +use super::*; +use shared::error::*; + +const CHROME_ALLOC_REQ_TEST_HEX: [&str; 4] = [ + "000300242112a442626b4a6849664c3630526863802f0016687474703a2f2f6c6f63616c686f73743a333030302f00000019000411000000", + "011300582112a442626b4a6849664c36305268630009001000000401556e617574686f72697a656400150010356130323039623563623830363130360014000b61312e63796465762e7275758022001a436f7475726e2d342e352e302e33202764616e204569646572272300", + "0003006c2112a442324e50695a437a4634535034802f0016687474703a2f2f6c6f63616c686f73743a333030302f000000190004110000000006000665726e61646f00000014000b61312e63796465762e7275000015001035613032303962356362383036313036000800145c8743f3b64bec0880cdd8d476d37b801a6c3d33", + "010300582112a442324e50695a437a4634535034001600080001fb922b1ab211002000080001adb2f49f38ae000d0004000002588022001a436f7475726e2d342e352e302e33202764616e204569646572277475000800145d7e85b767a519ffce91dbf0a96775e370db92e3", +]; + +#[test] +fn test_chrome_alloc_request() -> Result<()> { + let mut data = vec![]; + let mut messages = vec![]; + + // Decoding hex data into binary. + for h in &CHROME_ALLOC_REQ_TEST_HEX { + let b = match hex::decode(h) { + Ok(b) => b, + Err(_) => return Err(Error::Other("hex decode error".to_owned())), + }; + data.push(b); + } + + // All hex streams decoded to raw binary format and stored in data slice. + // Decoding packets to messages. + for packet in data { + let mut m = Message::new(); + m.write(&packet)?; + messages.push(m); + } + assert_eq!(messages.len(), 4, "unexpected message slice list"); + + Ok(()) +} diff --git a/rtc-turn/src/proto/relayaddr.rs b/rtc-turn/src/proto/relayaddr.rs new file mode 100644 index 0000000..7aa847b --- /dev/null +++ b/rtc-turn/src/proto/relayaddr.rs @@ -0,0 +1,70 @@ +#[cfg(test)] +mod relayaddr_test; + +use std::fmt; +use std::net::{IpAddr, Ipv4Addr}; + +use shared::error::Result; +use stun::attributes::*; +use stun::message::*; +use stun::xoraddr::*; + +// RelayedAddress implements XOR-RELAYED-ADDRESS attribute. +// +// It specifies the address and port that the server allocated to the +// client. It is encoded in the same way as XOR-MAPPED-ADDRESS. 
+// +// RFC 5766 Section 14.5 +#[derive(PartialEq, Eq, Debug)] +pub struct RelayedAddress { + pub ip: IpAddr, + pub port: u16, +} + +impl Default for RelayedAddress { + fn default() -> Self { + RelayedAddress { + ip: IpAddr::V4(Ipv4Addr::from(0)), + port: 0, + } + } +} + +impl fmt::Display for RelayedAddress { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.ip { + IpAddr::V4(_) => write!(f, "{}:{}", self.ip, self.port), + IpAddr::V6(_) => write!(f, "[{}]:{}", self.ip, self.port), + } + } +} + +impl Setter for RelayedAddress { + // AddTo adds XOR-PEER-ADDRESS to message. + fn add_to(&self, m: &mut Message) -> Result<()> { + let a = XorMappedAddress { + ip: self.ip, + port: self.port, + }; + a.add_to_as(m, ATTR_XOR_RELAYED_ADDRESS) + } +} + +impl Getter for RelayedAddress { + // GetFrom decodes XOR-PEER-ADDRESS from message. + fn get_from(&mut self, m: &Message) -> Result<()> { + let mut a = XorMappedAddress::default(); + a.get_from_as(m, ATTR_XOR_RELAYED_ADDRESS)?; + self.ip = a.ip; + self.port = a.port; + Ok(()) + } +} + +// XORRelayedAddress implements XOR-RELAYED-ADDRESS attribute. +// +// It specifies the address and port that the server allocated to the +// client. It is encoded in the same way as XOR-MAPPED-ADDRESS. +// +// RFC 5766 Section 14.5 +pub type XorRelayedAddress = RelayedAddress; diff --git a/rtc-turn/src/proto/relayaddr/relayaddr_test.rs b/rtc-turn/src/proto/relayaddr/relayaddr_test.rs new file mode 100644 index 0000000..71aa96f --- /dev/null +++ b/rtc-turn/src/proto/relayaddr/relayaddr_test.rs @@ -0,0 +1,26 @@ +use std::net::Ipv4Addr; + +use super::*; + +#[test] +fn test_relayed_address() -> Result<()> { + // Simple tests because already tested in stun. + let a = RelayedAddress { + ip: IpAddr::V4(Ipv4Addr::new(111, 11, 1, 2)), + port: 333, + }; + + assert_eq!(a.to_string(), "111.11.1.2:333", "invalid string"); + + let mut m = Message::new(); + a.add_to(&mut m)?; + m.write_header(); + + let mut decoded = Message::new(); + decoded.write(&m.raw)?; + + let mut a_got = RelayedAddress::default(); + a_got.get_from(&decoded)?; + + Ok(()) +} diff --git a/rtc-turn/src/proto/reqfamily.rs b/rtc-turn/src/proto/reqfamily.rs new file mode 100644 index 0000000..722ed3c --- /dev/null +++ b/rtc-turn/src/proto/reqfamily.rs @@ -0,0 +1,63 @@ +#[cfg(test)] +mod reqfamily_test; + +use std::fmt; + +use stun::attributes::*; +use stun::checks::*; +use stun::message::*; + +use shared::error::{Error, Result}; + +// Values for RequestedAddressFamily as defined in RFC 6156 Section 4.1.1. +pub const REQUESTED_FAMILY_IPV4: RequestedAddressFamily = RequestedAddressFamily(0x01); +pub const REQUESTED_FAMILY_IPV6: RequestedAddressFamily = RequestedAddressFamily(0x02); + +// RequestedAddressFamily represents the REQUESTED-ADDRESS-FAMILY Attribute as +// defined in RFC 6156 Section 4.1.1. +#[derive(Debug, Default, PartialEq, Eq)] +pub struct RequestedAddressFamily(pub u8); + +impl fmt::Display for RequestedAddressFamily { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + REQUESTED_FAMILY_IPV4 => "IPv4", + REQUESTED_FAMILY_IPV6 => "IPv6", + _ => "unknown", + }; + write!(f, "{s}") + } +} + +const REQUESTED_FAMILY_SIZE: usize = 4; + +impl Setter for RequestedAddressFamily { + // AddTo adds REQUESTED-ADDRESS-FAMILY to message. + fn add_to(&self, m: &mut Message) -> Result<()> { + let mut v = vec![0; REQUESTED_FAMILY_SIZE]; + v[0] = self.0; + // b[1:4] is RFFU = 0. 
+ // The RFFU field MUST be set to zero on transmission and MUST be + // ignored on reception. It is reserved for future uses. + m.add(ATTR_REQUESTED_ADDRESS_FAMILY, &v); + Ok(()) + } +} + +impl Getter for RequestedAddressFamily { + // GetFrom decodes REQUESTED-ADDRESS-FAMILY from message. + fn get_from(&mut self, m: &Message) -> Result<()> { + let v = m.get(ATTR_REQUESTED_ADDRESS_FAMILY)?; + check_size( + ATTR_REQUESTED_ADDRESS_FAMILY, + v.len(), + REQUESTED_FAMILY_SIZE, + )?; + + if v[0] != REQUESTED_FAMILY_IPV4.0 && v[0] != REQUESTED_FAMILY_IPV6.0 { + return Err(Error::Other("ErrInvalidRequestedFamilyValue".into())); + } + self.0 = v[0]; + Ok(()) + } +} diff --git a/rtc-turn/src/proto/reqfamily/reqfamily_test.rs b/rtc-turn/src/proto/reqfamily/reqfamily_test.rs new file mode 100644 index 0000000..83c967e --- /dev/null +++ b/rtc-turn/src/proto/reqfamily/reqfamily_test.rs @@ -0,0 +1,78 @@ +use super::*; +use shared::error::Error; + +#[test] +fn test_requested_address_family_string() -> Result<()> { + assert_eq!( + REQUESTED_FAMILY_IPV4.to_string(), + "IPv4", + "bad string {}, expected {}", + REQUESTED_FAMILY_IPV4, + "IPv4" + ); + + assert_eq!( + REQUESTED_FAMILY_IPV6.to_string(), + "IPv6", + "bad string {}, expected {}", + REQUESTED_FAMILY_IPV6, + "IPv6" + ); + + assert_eq!( + RequestedAddressFamily(0x04).to_string(), + "unknown", + "should be unknown" + ); + + Ok(()) +} + +#[test] +fn test_requested_address_family_add_to() -> Result<()> { + let mut m = Message::new(); + let r = REQUESTED_FAMILY_IPV4; + r.add_to(&mut m)?; + m.write_header(); + + //"GetFrom" + { + let mut decoded = Message::new(); + decoded.write(&m.raw)?; + let mut req = RequestedAddressFamily::default(); + req.get_from(&decoded)?; + assert_eq!(req, r, "Decoded {req}, expected {r}"); + + //"HandleErr" + { + let mut m = Message::new(); + let mut handle = RequestedAddressFamily::default(); + if let Err(err) = handle.get_from(&m) { + assert_eq!( + Error::ErrAttributeNotFound, + err, + "{err} should be not found" + ); + } else { + panic!("expected error, but got ok"); + } + m.add(ATTR_REQUESTED_ADDRESS_FAMILY, &[1, 2, 3]); + if let Err(err) = handle.get_from(&m) { + assert!( + is_attr_size_invalid(&err), + "IsAttrSizeInvalid should be true" + ); + } else { + panic!("expected error, but got ok"); + } + m.reset(); + m.add(ATTR_REQUESTED_ADDRESS_FAMILY, &[5, 0, 0, 0]); + assert!( + handle.get_from(&m).is_err(), + "should error on invalid value" + ); + } + } + + Ok(()) +} diff --git a/rtc-turn/src/proto/reqtrans.rs b/rtc-turn/src/proto/reqtrans.rs new file mode 100644 index 0000000..44d8606 --- /dev/null +++ b/rtc-turn/src/proto/reqtrans.rs @@ -0,0 +1,55 @@ +#[cfg(test)] +mod reqtrans_test; + +use std::fmt; + +use stun::attributes::*; +use stun::checks::*; +use stun::message::*; + +use super::*; +use shared::error::Result; + +// RequestedTransport represents REQUESTED-TRANSPORT attribute. +// +// This attribute is used by the client to request a specific transport +// protocol for the allocated transport address. RFC 5766 only allows the use of +// codepoint 17 (User Datagram protocol). +// +// RFC 5766 Section 14.7 +#[derive(Default, Debug, PartialEq, Eq)] +pub struct RequestedTransport { + pub protocol: Protocol, +} + +impl fmt::Display for RequestedTransport { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "protocol: {}", self.protocol) + } +} + +const REQUESTED_TRANSPORT_SIZE: usize = 4; + +impl Setter for RequestedTransport { + // AddTo adds REQUESTED-TRANSPORT to message. 
+ fn add_to(&self, m: &mut Message) -> Result<()> { + let mut v = vec![0; REQUESTED_TRANSPORT_SIZE]; + v[0] = self.protocol.0; + // b[1:4] is RFFU = 0. + // The RFFU field MUST be set to zero on transmission and MUST be + // ignored on reception. It is reserved for future uses. + m.add(ATTR_REQUESTED_TRANSPORT, &v); + Ok(()) + } +} + +impl Getter for RequestedTransport { + // GetFrom decodes REQUESTED-TRANSPORT from message. + fn get_from(&mut self, m: &Message) -> Result<()> { + let v = m.get(ATTR_REQUESTED_TRANSPORT)?; + + check_size(ATTR_REQUESTED_TRANSPORT, v.len(), REQUESTED_TRANSPORT_SIZE)?; + self.protocol = Protocol(v[0]); + Ok(()) + } +} diff --git a/rtc-turn/src/proto/reqtrans/reqtrans_test.rs b/rtc-turn/src/proto/reqtrans/reqtrans_test.rs new file mode 100644 index 0000000..60d6717 --- /dev/null +++ b/rtc-turn/src/proto/reqtrans/reqtrans_test.rs @@ -0,0 +1,76 @@ +use super::*; +use shared::error::Error; + +#[test] +fn test_requested_transport_string() -> Result<()> { + let mut r = RequestedTransport { + protocol: PROTO_UDP, + }; + assert_eq!( + r.to_string(), + "protocol: UDP", + "bad string {}, expected {}", + r, + "protocol: UDP", + ); + r.protocol = Protocol(254); + if r.to_string() != "protocol: 254" { + assert_eq!( + r.to_string(), + "protocol: UDP", + "bad string {}, expected {}", + r, + "protocol: 254", + ); + } + + Ok(()) +} + +#[test] +fn test_requested_transport_add_to() -> Result<()> { + let mut m = Message::new(); + let r = RequestedTransport { + protocol: PROTO_UDP, + }; + r.add_to(&mut m)?; + m.write_header(); + + //"GetFrom" + { + let mut decoded = Message::new(); + decoded.write(&m.raw)?; + let mut req = RequestedTransport { + protocol: PROTO_UDP, + }; + req.get_from(&decoded)?; + assert_eq!(req, r, "Decoded {req}, expected {r}"); + + //"HandleErr" + { + let mut m = Message::new(); + let mut handle = RequestedTransport::default(); + if let Err(err) = handle.get_from(&m) { + assert_eq!( + Error::ErrAttributeNotFound, + err, + "{err} should be not found" + ); + } else { + panic!("expected error, got ok"); + } + + m.add(ATTR_REQUESTED_TRANSPORT, &[1, 2, 3]); + if let Err(err) = handle.get_from(&m) { + assert!( + is_attr_size_invalid(&err), + "IsAttrSizeInvalid should be true" + ); + } else { + panic!("expected error, got ok"); + } + } + } + + Ok(()) +} diff --git a/rtc-turn/src/proto/rsrvtoken.rs b/rtc-turn/src/proto/rsrvtoken.rs new file mode 100644 index 0000000..2ecfa46 --- /dev/null +++ b/rtc-turn/src/proto/rsrvtoken.rs @@ -0,0 +1,41 @@ +#[cfg(test)] +mod rsrvtoken_test; + +use shared::error::Result; +use stun::attributes::*; +use stun::checks::*; +use stun::message::*; + +// ReservationToken represents RESERVATION-TOKEN attribute. +// +// The RESERVATION-TOKEN attribute contains a token that uniquely +// identifies a relayed transport address being held in reserve by the +// server. The server includes this attribute in a success response to +// tell the client about the token, and the client includes this +// attribute in a subsequent Allocate request to request the server use +// that relayed transport address for the allocation. +// +// RFC 5766 Section 14.9 +#[derive(Debug, Default, PartialEq, Eq)] +pub struct ReservationToken(pub Vec); + +const RESERVATION_TOKEN_SIZE: usize = 8; // 8 bytes + +impl Setter for ReservationToken { + // AddTo adds RESERVATION-TOKEN to message. 
+ fn add_to(&self, m: &mut Message) -> Result<()> { + check_size(ATTR_RESERVATION_TOKEN, self.0.len(), RESERVATION_TOKEN_SIZE)?; + m.add(ATTR_RESERVATION_TOKEN, &self.0); + Ok(()) + } +} + +impl Getter for ReservationToken { + // GetFrom decodes RESERVATION-TOKEN from message. + fn get_from(&mut self, m: &Message) -> Result<()> { + let v = m.get(ATTR_RESERVATION_TOKEN)?; + check_size(ATTR_RESERVATION_TOKEN, v.len(), RESERVATION_TOKEN_SIZE)?; + self.0 = v; + Ok(()) + } +} diff --git a/rtc-turn/src/proto/rsrvtoken/rsrvtoken_test.rs b/rtc-turn/src/proto/rsrvtoken/rsrvtoken_test.rs new file mode 100644 index 0000000..4bd15b1 --- /dev/null +++ b/rtc-turn/src/proto/rsrvtoken/rsrvtoken_test.rs @@ -0,0 +1,61 @@ +use super::*; +use shared::error::Error; + +#[test] +fn test_reservation_token() -> Result<()> { + let mut m = Message::new(); + let mut v = vec![0; 8]; + v[2] = 33; + v[7] = 1; + let tk = ReservationToken(v); + tk.add_to(&mut m)?; + m.write_header(); + + //"HandleErr" + { + let bad_tk = ReservationToken(vec![34, 45]); + if let Err(err) = bad_tk.add_to(&mut m) { + assert!( + is_attr_size_invalid(&err), + "IsAttrSizeInvalid should be true" + ); + } else { + panic!("expected error, but got ok"); + } + } + + //"GetFrom" + { + let mut decoded = Message::new(); + decoded.write(&m.raw)?; + let mut tok = ReservationToken::default(); + tok.get_from(&decoded)?; + assert_eq!(tok, tk, "Decoded {tok:?}, expected {tk:?}"); + + //"HandleErr" + { + let mut m = Message::new(); + let mut handle = ReservationToken::default(); + if let Err(err) = handle.get_from(&m) { + assert_eq!( + Error::ErrAttributeNotFound, + err, + "{err} should be not found" + ); + } else { + panic!("expected error, but got ok"); + } + m.add(ATTR_RESERVATION_TOKEN, &[1, 2, 3]); + if let Err(err) = handle.get_from(&m) { + assert!( + is_attr_size_invalid(&err), + "IsAttrSizeInvalid should be true" + ); + } else { + panic!("expected error, got ok"); + } + } + } + + Ok(()) +}
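// --- Editorial sketch, not part of this patch --------------------------------
// How the proto attributes introduced above compose into a TURN Allocate
// request, mirroring the `msg.build(&[...])` pattern used in
// rtc-turn/src/client/relay_conn.rs. The crate idents and import paths
// (`rtc_turn`, `stun::agent`, `shared`) and the free function below are
// assumptions for illustration. A real client would also add USERNAME, REALM,
// NONCE, MESSAGE-INTEGRITY and FINGERPRINT once the server has issued a 401
// challenge, as relay_conn.rs does for its Refresh requests.
use std::time::Duration;

use rtc_turn::proto::dontfrag::DontFragmentAttr;
use rtc_turn::proto::lifetime::Lifetime;
use rtc_turn::proto::reqtrans::RequestedTransport;
use rtc_turn::proto::{allocate_request, PROTO_UDP};
use shared::error::Result;
use stun::agent::TransactionId;
use stun::message::Message;

fn build_allocate_request() -> Result<Message> {
    let mut msg = Message::new();
    msg.build(&[
        // Fresh transaction id plus the Allocate request type shorthand.
        Box::new(TransactionId::new()),
        Box::new(allocate_request()),
        // RFC 5766 only allows UDP (protocol number 17) between server and peer.
        Box::new(RequestedTransport { protocol: PROTO_UDP }),
        // Ask for the default 10-minute lifetime; the server may shorten it.
        Box::new(Lifetime(Duration::from_secs(600))),
        // Request that the DF bit be set on relayed datagrams.
        Box::new(DontFragmentAttr),
    ])?;
    Ok(msg)
}
// ------------------------------------------------------------------------------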