Skip to content

Commit

Permalink
feat(packet): remove ack resend and add lock for socket
Browse files Browse the repository at this point in the history
  • Loading branch information
fu050409 committed Apr 24, 2024
1 parent 1418735 commit 6c103d4
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 236 deletions.
2 changes: 1 addition & 1 deletion src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async fn main() -> Result<()> {
path_route!(router, "/json" => json);
path_route!(router, "/alive" => alive);

let mut server = Server::new("0.0.0.0", 7076, router);
let server = Server::new("0.0.0.0", 7076, router);
server.run().await?;
}
_ => {
Expand Down
33 changes: 11 additions & 22 deletions src/models/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@
use std::{collections::VecDeque, sync::Arc};

use anyhow::{Error, Result};
use tokio::{
net::TcpStream,
sync::{Mutex, RwLock},
task::JoinHandle,
};
use tokio::{net::TcpStream, sync::Mutex, task::JoinHandle};

use crate::exceptions::Exception;
#[cfg(feature = "python")]
Expand Down Expand Up @@ -125,7 +121,7 @@ pub struct Client {
pub entrance: String,
pub path: OblivionPath,
pub header: String,
pub session: Arc<RwLock<Session>>,
pub session: Arc<Session>,
pub responses: Arc<Mutex<VecDeque<Response>>>,
}

Expand All @@ -151,37 +147,31 @@ impl Client {
entrance: entrance.to_string(),
path,
header,
session: Arc::new(RwLock::new(session)),
session: Arc::new(session),
responses: Arc::new(Mutex::new(VecDeque::new())),
})
}

pub async fn send(&self, data: Vec<u8>, status_code: u32) -> Result<()> {
let session = self.session.read().await;
Ok(session.send(data, status_code).await?)
Ok(self.session.send(data, status_code).await?)
}

pub async fn send_json(&self, json: Value, status_code: u32) -> Result<()> {
let session = self.session.read().await;
Ok(session
.send(json.to_string().into_bytes(), status_code)
.await?)
Ok(self.session.send_json(json, status_code).await?)
}

pub async fn recv(&self) -> Result<Response> {
let session = self.session.read().await;
Ok(session.recv().await?)
Ok(self.session.recv().await?)
}

pub async fn listen(&self) -> Result<JoinHandle<Result<()>>> {
let session = Arc::clone(&self.session);
let responses = Arc::clone(&self.responses);
Ok(tokio::spawn(async move {
loop {
let rsess = session.read().await;
let mut wres = responses.lock().await;
if !rsess.closed().await {
match rsess.recv().await {
if !session.closed().await {
match session.recv().await {
Ok(res) => {
if &res.flag == &1 {
wres.push_back(res);
Expand All @@ -190,9 +180,9 @@ impl Client {
wres.push_back(res);
}
Err(e) => {
if !rsess.closed().await {
if !session.closed().await {
eprintln!("{:?}", e);
rsess.close().await?;
session.close().await?;
}
break;
}
Expand All @@ -210,8 +200,7 @@ impl Client {
}

pub async fn close(&self) -> Result<()> {
let session = self.session.read().await;
session.close().await?;
self.session.close().await?;
Ok(())
}
}
184 changes: 53 additions & 131 deletions src/models/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,13 @@ use crate::utils::gear::Socket;
use crate::utils::generator::{generate_random_salt, SharedKey};
use crate::utils::parser::length;

use anyhow::{Error, Result};
use anyhow::Result;
use p256::ecdh::EphemeralSecret;
use p256::PublicKey;
use rand::Rng;
use serde_json::Value;

const STOP_FLAG: [u8; 4] = u32::MIN.to_be_bytes();

pub struct ACK {
sequence: u32,
}

impl ACK {
pub fn new() -> Self {
Self {
sequence: rand::thread_rng().gen_range(1000..=9999),
}
}

pub async fn from_stream(&mut self, stream: &mut Socket) -> Result<Self> {
Ok(Self {
sequence: stream.recv_u32().await?,
})
}

pub async fn to_stream(&mut self, stream: &mut Socket) -> Result<()> {
stream.send(&self.plain_data()).await?;
Ok(())
}

pub fn plain_data(&mut self) -> [u8; 4] {
self.sequence.to_be_bytes()
}
}

pub struct OSC {
pub status_code: u32,
}
Expand All @@ -50,12 +22,12 @@ impl OSC {
Self { status_code }
}

pub async fn from_stream(stream: &mut Socket) -> Result<Self> {
pub async fn from_stream(stream: &Socket) -> Result<Self> {
let status_code = stream.recv_u32().await?;
Ok(Self { status_code })
}

pub async fn to_stream(&mut self, stream: &mut Socket) -> Result<()> {
pub async fn to_stream(&mut self, stream: &Socket) -> Result<()> {
stream.send(&self.plain_data()).await?;
Ok(())
}
Expand Down Expand Up @@ -99,7 +71,7 @@ impl<'a> OKE<'a> {
Ok(self)
}

pub async fn from_stream(&mut self, stream: &mut Socket) -> Result<&mut Self> {
pub async fn from_stream(&mut self, stream: &Socket) -> Result<&mut Self> {
let remote_public_key_length = stream.recv_usize().await?;
let remote_public_key_bytes = stream.recv(remote_public_key_length).await?;
self.remote_public_key = Some(PublicKey::from_sec1_bytes(&remote_public_key_bytes)?);
Expand All @@ -111,7 +83,7 @@ impl<'a> OKE<'a> {
Ok(self)
}

pub async fn from_stream_with_salt(&mut self, stream: &mut Socket) -> Result<&mut Self> {
pub async fn from_stream_with_salt(&mut self, stream: &Socket) -> Result<&mut Self> {
let remote_public_key_length = stream.recv_usize().await?;
let remote_public_key_bytes = stream.recv(remote_public_key_length).await?;
self.remote_public_key = Some(PublicKey::from_sec1_bytes(&remote_public_key_bytes)?);
Expand All @@ -125,12 +97,12 @@ impl<'a> OKE<'a> {
Ok(self)
}

pub async fn to_stream(&mut self, stream: &mut Socket) -> Result<()> {
pub async fn to_stream(&mut self, stream: &Socket) -> Result<()> {
stream.send(&self.plain_data()?).await?;
Ok(())
}

pub async fn to_stream_with_salt(&mut self, stream: &mut Socket) -> Result<()> {
pub async fn to_stream_with_salt(&mut self, stream: &Socket) -> Result<()> {
stream.send(&self.plain_data()?).await?;
stream.send(&self.plain_salt()?).await?;
Ok(())
Expand Down Expand Up @@ -176,10 +148,7 @@ impl OED {
}
}

pub fn from_json_or_string(
&mut self,
json_or_str: String,
) -> Result<&mut Self, Exception> {
pub fn from_json_or_string(&mut self, json_or_str: String) -> Result<&mut Self, Exception> {
let (encrypted_data, tag, nonce) =
encrypt_plaintext(json_or_str, &self.aes_key.as_ref().unwrap())?;
(self.encrypted_data, self.tag, self.nonce) =
Expand Down Expand Up @@ -207,110 +176,63 @@ impl OED {
Ok(self)
}

pub async fn from_stream(
&mut self,
stream: &mut Socket,
total_attemps: u32,
) -> Result<&mut Self> {
let mut attemp = 0;
let mut ack = false;

while attemp < total_attemps {
let mut ack_packet = ACK::new();
let mut ack_packet = ack_packet.from_stream(stream).await?;

let len_nonce = stream.recv_usize().await?;
let len_tag = stream.recv_usize().await?;

self.nonce = Some(stream.recv(len_nonce).await?);
self.tag = Some(stream.recv(len_tag).await?);

let mut encrypted_data: Vec<u8> = Vec::new();
self.chunk_count = 0;

loop {
let prefix = stream.recv_usize().await?;
if prefix == 0 {
self.encrypted_data = Some(encrypted_data);
break;
}

let mut add: Vec<u8> = Vec::new();
while add.len() != prefix {
add.extend(stream.recv(prefix - add.len()).await?)
}

encrypted_data.extend(add);
self.chunk_count += 1;
}
pub async fn from_stream(&mut self, stream: &Socket) -> Result<&mut Self> {
let len_nonce = stream.recv_usize().await?;
let len_tag = stream.recv_usize().await?;

match decrypt_bytes(
self.encrypted_data.clone().unwrap(),
self.tag.as_ref().unwrap(),
self.aes_key.as_ref().unwrap(),
self.nonce.as_ref().unwrap(),
) {
Ok(data) => {
self.data = Some(data);
ack_packet.to_stream(stream).await?;
ack = true;
break;
}
Err(error) => {
stream.send(&STOP_FLAG).await?;
eprintln!("An error occured: {error}\nRetried {attemp} times.");
attemp += 1;
continue;
}
}
}
if !ack {
stream.close().await?;
return Err(Error::from(Exception::AllAttemptsRetryFailed {
times: total_attemps,
}));
}
self.nonce = Some(stream.recv(len_nonce).await?);
self.tag = Some(stream.recv(len_tag).await?);

Ok(self)
}
let mut encrypted_data: Vec<u8> = Vec::new();
self.chunk_count = 0;

pub async fn to_stream(&mut self, stream: &mut Socket, total_attemps: u32) -> Result<()> {
let attemp = 0;
let mut ack = false;
loop {
let prefix = stream.recv_usize().await?;
if prefix == 0 {
self.encrypted_data = Some(encrypted_data);
break;
}

while attemp <= total_attemps {
let mut ack_packet = ACK::new();
ack_packet.to_stream(stream).await?;
let mut add: Vec<u8> = Vec::new();
while add.len() != prefix {
add.extend(stream.recv(prefix - add.len()).await?)
}

stream.send(&self.plain_data()?).await?;
encrypted_data.extend(add);
self.chunk_count += 1;
}

self.chunk_count = 0;
let encrypted_data = self.encrypted_data.as_ref().unwrap();
let mut remaining_data = &encrypted_data[..];
while !remaining_data.is_empty() {
let chunk_size = remaining_data.len().min(2048);
match decrypt_bytes(
self.encrypted_data.clone().unwrap(),
self.tag.as_ref().unwrap(),
self.aes_key.as_ref().unwrap(),
self.nonce.as_ref().unwrap(),
) {
Ok(data) => {
self.data = Some(data);
Ok(self)
}
Err(error) => Err(Exception::DecryptError { error }.into()),
}
}

let chunk_length = chunk_size as u32;
pub async fn to_stream(&mut self, stream: &Socket) -> Result<()> {
stream.send(&self.plain_data()?).await?;

stream.send(&chunk_length.to_be_bytes()).await?;
stream.send(&remaining_data[..chunk_size]).await?;
self.chunk_count = 0;
let encrypted_data = self.encrypted_data.as_ref().unwrap();
let mut remaining_data = &encrypted_data[..];
while !remaining_data.is_empty() {
let chunk_size = remaining_data.len().min(2048);

remaining_data = &remaining_data[chunk_size..];
}
stream.send(&STOP_FLAG).await?;
let chunk_length = chunk_size as u32;

if ack_packet.sequence == stream.recv_u32().await? {
ack = true;
break;
}
}
stream.send(&chunk_length.to_be_bytes()).await?;
stream.send(&remaining_data[..chunk_size]).await?;

if !ack {
stream.close().await?;
return Err(Error::from(Exception::AllAttemptsRetryFailed {
times: total_attemps,
}));
remaining_data = &remaining_data[chunk_size..];
}
stream.send(&STOP_FLAG).await?;

Ok(())
}
Expand Down
Loading

0 comments on commit 6c103d4

Please sign in to comment.