Skip to content

Commit

Permalink
refactor(network): simplify local state mini-protocol implementation (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
scarmuega authored Nov 10, 2023
1 parent aae7d92 commit e0f9f14
Show file tree
Hide file tree
Showing 11 changed files with 599 additions and 547 deletions.
38 changes: 22 additions & 16 deletions examples/n2c-miniprotocols/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
use pallas::network::{
facades::NodeClient,
miniprotocols::{chainsync, localstate, Point, MAINNET_MAGIC},
miniprotocols::{chainsync, localstate::queries_v16, Point, PRE_PRODUCTION_MAGIC},

Check warning on line 3 in examples/n2c-miniprotocols/src/main.rs

View workflow job for this annotation

GitHub Actions / Check (windows-latest, stable)

unused import: `PRE_PRODUCTION_MAGIC`
};
use tracing::info;

async fn do_localstate_query(client: &mut NodeClient) {

Check warning on line 7 in examples/n2c-miniprotocols/src/main.rs

View workflow job for this annotation

GitHub Actions / Check (windows-latest, stable)

function `do_localstate_query` is never used
client.statequery().acquire(None).await.unwrap();
let client = client.statequery();

let result = client
.statequery()
.query(localstate::queries::Request::GetSystemStart)
client.acquire(None).await.unwrap();

let result = queries_v16::get_chain_point(client).await.unwrap();
info!("result: {:?}", result);

let result = queries_v16::get_system_start(client).await.unwrap();
info!("result: {:?}", result);

let era = queries_v16::get_current_era(client).await.unwrap();
info!("result: {:?}", era);

let result = queries_v16::get_block_epoch_number(client, era)
.await
.unwrap();

info!("system start result: {:?}", result);
info!("result: {:?}", result);

client.send_release().await.unwrap();
}

async fn do_chainsync(client: &mut NodeClient) {

Check warning on line 30 in examples/n2c-miniprotocols/src/main.rs

View workflow job for this annotation

GitHub Actions / Check (windows-latest, stable)

function `do_chainsync` is never used
Expand Down Expand Up @@ -43,6 +54,10 @@ async fn do_chainsync(client: &mut NodeClient) {
}
}

// change the following to match the Cardano node socket in your local
// environment
const SOCKET_PATH: &str = "/tmp/node.socket";

Check warning on line 59 in examples/n2c-miniprotocols/src/main.rs

View workflow job for this annotation

GitHub Actions / Check (windows-latest, stable)

constant `SOCKET_PATH` is never used

#[cfg(target_family = "unix")]
#[tokio::main]
async fn main() {
Expand All @@ -55,15 +70,7 @@ async fn main() {

// we connect to the unix socket of the local node. Make sure you have the right
// path for your environment
let socket_path = "/tmp/node.socket";

// we connect to the unix socket of the local node and perform a handshake query
let version_table = NodeClient::handshake_query(socket_path, MAINNET_MAGIC)
.await
.unwrap();
info!("handshake query result: {:?}", version_table);

let mut client = NodeClient::connect(socket_path, MAINNET_MAGIC)
let mut client = NodeClient::connect(SOCKET_PATH, PRE_PRODUCTION_MAGIC)
.await
.unwrap();

Expand All @@ -75,7 +82,6 @@ async fn main() {
}

#[cfg(not(target_family = "unix"))]

fn main() {
panic!("can't use n2c unix socket on non-unix systems");
}
2 changes: 1 addition & 1 deletion pallas-network/src/facades.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ impl NodeServer {
plexer_handle,
version: ver,
chainsync: server_cs,
statequery: server_sq
statequery: server_sq,
})
} else {
plexer_handle.abort();
Expand Down
80 changes: 44 additions & 36 deletions pallas-network/src/miniprotocols/localstate/client.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
use pallas_codec::utils::AnyCbor;
use std::fmt::Debug;

use pallas_codec::Fragment;

use std::marker::PhantomData;
use thiserror::*;

use super::{AcquireFailure, Message, Query, State};
use super::{AcquireFailure, Message, State};
use crate::miniprotocols::Point;
use crate::multiplexer;

#[derive(Error, Debug)]
pub enum ClientError {
#[error("attempted to receive message while agency is ours")]
AgencyIsOurs,

#[error("attempted to send message while agency is theirs")]
AgencyIsTheirs,

#[error("inbound message is not valid for current state")]
InvalidInbound,

#[error("outbound message is not valid for current state")]
InvalidOutbound,

#[error("failure acquiring point, not found")]
AcquirePointNotFound,

#[error("failure acquiring point, too old")]
AcquirePointTooOld,

#[error("failure decoding CBOR data")]
InvalidCbor(pallas_codec::minicbor::decode::Error),

#[error("error while sending or receiving data through the channel")]
Plexer(multiplexer::Error),
}
Expand All @@ -36,22 +42,11 @@ impl From<AcquireFailure> for ClientError {
}
}

pub struct GenericClient<Q>(State, multiplexer::ChannelBuffer, PhantomData<Q>)
where
Q: Query,
Message<Q>: Fragment;
pub struct GenericClient(State, multiplexer::ChannelBuffer);

impl<Q> GenericClient<Q>
where
Q: Query,
Message<Q>: Fragment,
{
impl GenericClient {
pub fn new(channel: multiplexer::AgentChannel) -> Self {
Self(
State::Idle,
multiplexer::ChannelBuffer::new(channel),
PhantomData {},
)
Self(State::Idle, multiplexer::ChannelBuffer::new(channel))
}

pub fn state(&self) -> &State {
Expand Down Expand Up @@ -87,7 +82,7 @@ where
}
}

fn assert_outbound_state(&self, msg: &Message<Q>) -> Result<(), ClientError> {
fn assert_outbound_state(&self, msg: &Message) -> Result<(), ClientError> {
match (&self.0, msg) {
(State::Idle, Message::Acquire(_)) => Ok(()),
(State::Idle, Message::Done) => Ok(()),
Expand All @@ -98,7 +93,7 @@ where
}
}

fn assert_inbound_state(&self, msg: &Message<Q>) -> Result<(), ClientError> {
fn assert_inbound_state(&self, msg: &Message) -> Result<(), ClientError> {
match (&self.0, msg) {
(State::Acquiring, Message::Acquired) => Ok(()),
(State::Acquiring, Message::Failure(_)) => Ok(()),
Expand All @@ -107,15 +102,18 @@ where
}
}

pub async fn send_message(&mut self, msg: &Message<Q>) -> Result<(), ClientError> {
pub async fn send_message(&mut self, msg: &Message) -> Result<(), ClientError> {
self.assert_agency_is_ours()?;
self.assert_outbound_state(msg)?;
self.1.send_msg_chunks(msg).await.map_err(ClientError::Plexer)?;
self.1
.send_msg_chunks(msg)
.await
.map_err(ClientError::Plexer)?;

Ok(())
}

pub async fn recv_message(&mut self) -> Result<Message<Q>, ClientError> {
pub async fn recv_message(&mut self) -> Result<Message, ClientError> {
self.assert_agency_is_theirs()?;
let msg = self.1.recv_full_msg().await.map_err(ClientError::Plexer)?;
self.assert_inbound_state(&msg)?;
Expand All @@ -124,31 +122,31 @@ where
}

pub async fn send_acquire(&mut self, point: Option<Point>) -> Result<(), ClientError> {
let msg = Message::<Q>::Acquire(point);
let msg = Message::Acquire(point);
self.send_message(&msg).await?;
self.0 = State::Acquiring;

Ok(())
}

pub async fn send_reacquire(&mut self, point: Option<Point>) -> Result<(), ClientError> {
let msg = Message::<Q>::ReAcquire(point);
let msg = Message::ReAcquire(point);
self.send_message(&msg).await?;
self.0 = State::Acquiring;

Ok(())
}

pub async fn send_release(&mut self) -> Result<(), ClientError> {
let msg = Message::<Q>::Release;
let msg = Message::Release;
self.send_message(&msg).await?;
self.0 = State::Idle;

Ok(())
}

pub async fn send_done(&mut self) -> Result<(), ClientError> {
let msg = Message::<Q>::Done;
let msg = Message::Done;
self.send_message(&msg).await?;
self.0 = State::Done;

Expand All @@ -174,28 +172,38 @@ where
self.recv_while_acquiring().await
}

pub async fn send_query(&mut self, request: Q::Request) -> Result<(), ClientError> {
let msg = Message::<Q>::Query(request);
pub async fn send_query(&mut self, request: AnyCbor) -> Result<Message, ClientError> {
let msg = Message::Query(request);
self.send_message(&msg).await?;
self.0 = State::Querying;

Ok(())
Ok(msg)
}

pub async fn recv_while_querying(&mut self) -> Result<Q::Response, ClientError> {
pub async fn recv_while_querying(&mut self) -> Result<AnyCbor, ClientError> {
match self.recv_message().await? {
Message::Result(x) => {
Message::Result(result) => {
self.0 = State::Acquired;
Ok(x)
Ok(result)
}
_ => Err(ClientError::InvalidInbound),
}
}

pub async fn query(&mut self, request: Q::Request) -> Result<Q::Response, ClientError> {
pub async fn query_any(&mut self, request: AnyCbor) -> Result<AnyCbor, ClientError> {
self.send_query(request).await?;
self.recv_while_querying().await
}

pub async fn query<Q, R>(&mut self, request: Q) -> Result<R, ClientError>
where
Q: pallas_codec::minicbor::Encode<()>,
for<'b> R: pallas_codec::minicbor::Decode<'b, ()>,
{
let request = AnyCbor::from_encode(request);
let response = self.query_any(request).await?;
response.into_decode().map_err(ClientError::InvalidCbor)
}
}

pub type Client = GenericClient<super::queries::QueryV16>;
pub type Client = GenericClient;
16 changes: 3 additions & 13 deletions pallas-network/src/miniprotocols/localstate/codec.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pallas_codec::minicbor::{decode, encode, Decode, Encode, Encoder};

use super::{AcquireFailure, Message, Query};
use super::{AcquireFailure, Message};

impl Encode<()> for AcquireFailure {
fn encode<W: encode::Write>(
Expand Down Expand Up @@ -36,12 +36,7 @@ impl<'b> Decode<'b, ()> for AcquireFailure {
}
}

impl<Q> Encode<()> for Message<Q>
where
Q: Query,
Q::Request: Encode<()>,
Q::Response: Encode<()>,
{
impl Encode<()> for Message {
fn encode<W: encode::Write>(
&self,
e: &mut Encoder<W>,
Expand Down Expand Up @@ -97,12 +92,7 @@ where
}
}

impl<'b, Q> Decode<'b, ()> for Message<Q>
where
Q: Query,
Q::Request: Decode<'b, ()>,
Q::Response: Decode<'b, ()>,
{
impl<'b> Decode<'b, ()> for Message {
fn decode(
d: &mut pallas_codec::minicbor::Decoder<'b>,
_ctx: &mut (),
Expand Down
3 changes: 2 additions & 1 deletion pallas-network/src/miniprotocols/localstate/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
mod client;
mod codec;
mod protocol;
pub mod queries;
mod server;

pub mod queries_v16;

pub use client::*;
pub use codec::*;
pub use protocol::*;
Expand Down
13 changes: 5 additions & 8 deletions pallas-network/src/miniprotocols/localstate/protocol.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt::Debug;

use pallas_codec::utils::AnyCbor;

use crate::miniprotocols::Point;

#[derive(Debug, PartialEq, Eq, Clone)]
Expand All @@ -17,18 +19,13 @@ pub enum AcquireFailure {
PointNotOnChain,
}

pub trait Query: Debug {
type Request: Clone + Debug;
type Response: Clone + Debug;
}

#[derive(Debug)]
pub enum Message<Q: Query> {
pub enum Message {
Acquire(Option<Point>),
Failure(AcquireFailure),
Acquired,
Query(Q::Request),
Result(Q::Response),
Query(AnyCbor),
Result(AnyCbor),
ReAcquire(Option<Point>),
Release,
Done,
Expand Down
Loading

0 comments on commit e0f9f14

Please sign in to comment.