Skip to content

Commit

Permalink
feat: forbid invalid invoice state transitions
Browse files Browse the repository at this point in the history
  • Loading branch information
michael1011 committed Aug 21, 2024
1 parent 95b545d commit 32f4c4d
Show file tree
Hide file tree
Showing 9 changed files with 358 additions and 54 deletions.
54 changes: 45 additions & 9 deletions src/database/helpers/invoice_helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,25 @@ use diesel::{QueryDsl, RunQueryDsl, SelectableHelper};
pub trait InvoiceHelper {
fn insert(&self, invoice: &InvoiceInsertable) -> Result<usize>;
fn insert_htlc(&self, htlc: &HtlcInsertable) -> Result<usize>;
fn set_invoice_state(&self, id: i64, state: InvoiceState) -> Result<usize>;
fn set_invoice_state(
&self,
id: i64,
state: InvoiceState,
new_state: InvoiceState,
) -> Result<usize>;
fn set_invoice_preimage(&self, id: i64, preimage: &[u8]) -> Result<usize>;
fn set_htlc_state_by_id(&self, htlc_id: i64, state: InvoiceState) -> Result<usize>;
fn set_htlc_states_by_invoice(&self, invoice_id: i64, state: InvoiceState) -> Result<usize>;
fn set_htlc_state_by_id(
&self,
htlc_id: i64,
state: InvoiceState,
new_state: InvoiceState,
) -> Result<usize>;
fn set_htlc_states_by_invoice(
&self,
invoice_id: i64,
state: InvoiceState,
new_state: InvoiceState,
) -> Result<usize>;
fn get_all(&self) -> Result<Vec<HoldInvoice>>;
fn get_paginated(&self, index_start: i64, limit: u64) -> Result<Vec<HoldInvoice>>;
fn get_by_payment_hash(&self, payment_hash: &[u8]) -> Result<Option<HoldInvoice>>;
Expand Down Expand Up @@ -43,10 +58,17 @@ impl InvoiceHelper for InvoiceHelperDatabase {
.execute(&mut self.pool.get()?)?)
}

fn set_invoice_state(&self, id: i64, state: InvoiceState) -> Result<usize> {
fn set_invoice_state(
&self,
id: i64,
state: InvoiceState,
new_state: InvoiceState,
) -> Result<usize> {
state.validate_transition(new_state)?;

Ok(update(invoices::dsl::invoices)
.filter(invoices::dsl::id.eq(id))
.set(invoices::dsl::state.eq(state.to_string()))
.set(invoices::dsl::state.eq(new_state.to_string()))
.execute(&mut self.pool.get()?)?)
}

Expand All @@ -57,17 +79,31 @@ impl InvoiceHelper for InvoiceHelperDatabase {
.execute(&mut self.pool.get()?)?)
}

fn set_htlc_state_by_id(&self, htlc_id: i64, state: InvoiceState) -> Result<usize> {
fn set_htlc_state_by_id(
&self,
htlc_id: i64,
state: InvoiceState,
new_state: InvoiceState,
) -> Result<usize> {
state.validate_transition(new_state)?;

Ok(update(htlcs::dsl::htlcs)
.filter(htlcs::dsl::id.eq(htlc_id))
.set(htlcs::dsl::state.eq(state.to_string()))
.set(htlcs::dsl::state.eq(new_state.to_string()))
.execute(&mut self.pool.get()?)?)
}

fn set_htlc_states_by_invoice(&self, invoice_id: i64, state: InvoiceState) -> Result<usize> {
fn set_htlc_states_by_invoice(
&self,
invoice_id: i64,
state: InvoiceState,
new_state: InvoiceState,
) -> Result<usize> {
state.validate_transition(new_state)?;

Ok(update(htlcs::dsl::htlcs)
.filter(htlcs::dsl::invoice_id.eq(invoice_id))
.set(htlcs::dsl::state.eq(state.to_string()))
.set(htlcs::dsl::state.eq(new_state.to_string()))
.execute(&mut self.pool.get()?)?)
}

Expand Down
171 changes: 167 additions & 4 deletions src/database/model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use diesel::internal::derives::multiconnection::chrono;
use diesel::{AsChangeset, Associations, Identifiable, Insertable, Queryable, Selectable};
use serde::Serialize;
use std::error::Error;
use std::fmt::{Display, Formatter};

#[derive(Queryable, Identifiable, Selectable, AsChangeset, Serialize, Debug, PartialEq, Clone)]
Expand Down Expand Up @@ -56,6 +57,42 @@ pub struct HtlcInsertable {
pub msat: i64,
}

#[derive(Debug, PartialEq)]
pub enum StateTransitionError {
IsFinal(InvoiceState),
InvalidTransition(InvoiceState, InvoiceState),
}

impl Display for StateTransitionError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match *self {
StateTransitionError::IsFinal(state) => write!(f, "state {} is final", state),
StateTransitionError::InvalidTransition(old, new) => {
write!(f, "invoice state transition ({} -> {})", old, new)
}
}
}
}

impl Error for StateTransitionError {}

#[derive(Debug, PartialEq)]
pub enum InvoiceStateParsingError {
InvalidInvariant(String),
}

impl Display for InvoiceStateParsingError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
InvoiceStateParsingError::InvalidInvariant(state) => {
write!(f, "invalid invoice state invariant: {}", state)
}
}
}
}

impl Error for InvoiceStateParsingError {}

#[derive(Debug, PartialEq, Clone, Copy)]
pub enum InvoiceState {
Paid = 0,
Expand Down Expand Up @@ -83,23 +120,55 @@ impl From<InvoiceState> for String {
}

impl TryFrom<&str> for InvoiceState {
type Error = &'static str;
type Error = InvoiceStateParsingError;

fn try_from(value: &str) -> Result<Self, Self::Error> {
match value {
"paid" => Ok(InvoiceState::Paid),
"unpaid" => Ok(InvoiceState::Unpaid),
"accepted" => Ok(InvoiceState::Accepted),
"cancelled" => Ok(InvoiceState::Cancelled),
&_ => Err("unknown state invariant"),
&_ => Err(InvoiceStateParsingError::InvalidInvariant(
value.to_string(),
)),
}
}
}

impl TryFrom<&String> for InvoiceState {
type Error = InvoiceStateParsingError;

fn try_from(value: &String) -> Result<Self, Self::Error> {
InvoiceState::try_from(value.as_str())
}
}

impl InvoiceState {
pub fn is_final(&self) -> bool {
*self == InvoiceState::Paid || *self == InvoiceState::Cancelled
}

pub fn validate_transition(&self, new_state: InvoiceState) -> Result<(), StateTransitionError> {
if self.is_final() {
return Err(StateTransitionError::IsFinal(*self));
}

match *self {
InvoiceState::Unpaid => {
if new_state != InvoiceState::Accepted && new_state != InvoiceState::Cancelled {
return Err(StateTransitionError::InvalidTransition(*self, new_state));
}
}
InvoiceState::Accepted => {
if new_state == InvoiceState::Unpaid {
return Err(StateTransitionError::InvalidTransition(*self, new_state));
}
}
_ => {}
};

Ok(())
}
}

#[derive(Serialize, Clone, Debug)]
Expand Down Expand Up @@ -134,7 +203,9 @@ impl HoldInvoice {

#[cfg(test)]
mod test {
use crate::database::model::{HoldInvoice, Htlc, Invoice, InvoiceState};
use crate::database::model::{
HoldInvoice, Htlc, Invoice, InvoiceState, InvoiceStateParsingError, StateTransitionError,
};

#[test]
fn invoice_state_to_string() {
Expand Down Expand Up @@ -162,7 +233,34 @@ mod test {

assert_eq!(
InvoiceState::try_from("invalid").err().unwrap(),
"unknown state invariant"
InvoiceStateParsingError::InvalidInvariant("invalid".to_string())
);
}

#[test]
fn invoice_state_from_string() {
assert_eq!(
InvoiceState::try_from(&String::from("paid")).unwrap(),
InvoiceState::Paid
);
assert_eq!(
InvoiceState::try_from(&String::from("unpaid")).unwrap(),
InvoiceState::Unpaid
);
assert_eq!(
InvoiceState::try_from(&String::from("accepted")).unwrap(),
InvoiceState::Accepted
);
assert_eq!(
InvoiceState::try_from(&String::from("cancelled")).unwrap(),
InvoiceState::Cancelled
);

assert_eq!(
InvoiceState::try_from(&String::from("invalid"))
.err()
.unwrap(),
InvoiceStateParsingError::InvalidInvariant("invalid".to_string())
);
}

Expand All @@ -175,6 +273,71 @@ mod test {
assert!(!InvoiceState::Accepted.is_final());
}

#[test]
fn invoice_state_validate() {
assert_eq!(
InvoiceState::Unpaid
.validate_transition(InvoiceState::Accepted)
.unwrap(),
(),
);
assert_eq!(
InvoiceState::Unpaid
.validate_transition(InvoiceState::Cancelled)
.unwrap(),
()
);

assert_eq!(
InvoiceState::Accepted
.validate_transition(InvoiceState::Paid)
.unwrap(),
(),
);
assert_eq!(
InvoiceState::Unpaid
.validate_transition(InvoiceState::Cancelled)
.unwrap(),
()
);
}

#[test]
fn invoice_state_validate_transition_final() {
assert_eq!(
InvoiceState::Paid
.validate_transition(InvoiceState::Accepted)
.err()
.unwrap(),
StateTransitionError::IsFinal(InvoiceState::Paid)
);
assert_eq!(
InvoiceState::Cancelled
.validate_transition(InvoiceState::Accepted)
.err()
.unwrap(),
StateTransitionError::IsFinal(InvoiceState::Cancelled)
);
}

#[test]
fn invoice_state_validate_transition_invalid() {
assert_eq!(
InvoiceState::Unpaid
.validate_transition(InvoiceState::Paid)
.err()
.unwrap(),
StateTransitionError::InvalidTransition(InvoiceState::Unpaid, InvoiceState::Paid)
);
assert_eq!(
InvoiceState::Accepted
.validate_transition(InvoiceState::Unpaid)
.err()
.unwrap(),
StateTransitionError::InvalidTransition(InvoiceState::Accepted, InvoiceState::Unpaid)
);
}

#[test]
fn hold_invoice_amount_paid_msat() {
let mut invoice = HoldInvoice::new(
Expand Down
21 changes: 18 additions & 3 deletions src/grpc/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,25 @@ mod test {
impl InvoiceHelper for InvoiceHelper {
fn insert(&self, invoice: &InvoiceInsertable) -> Result<usize>;
fn insert_htlc(&self, htlc: &HtlcInsertable) -> Result<usize>;
fn set_invoice_state(&self, id: i64, state: InvoiceState) -> Result<usize>;
fn set_invoice_state(
&self,
id: i64,
state: InvoiceState,
new_state: InvoiceState,
) -> Result<usize>;
fn set_invoice_preimage(&self, id: i64, preimage: &[u8]) -> Result<usize>;
fn set_htlc_state_by_id(&self, htlc_id: i64, state: InvoiceState) -> Result<usize>;
fn set_htlc_states_by_invoice(&self, invoice_id: i64, state: InvoiceState) -> Result<usize>;
fn set_htlc_state_by_id(
&self,
htlc_id: i64,
state: InvoiceState,
new_state: InvoiceState,
) -> Result<usize>;
fn set_htlc_states_by_invoice(
&self,
invoice_id: i64,
state: InvoiceState,
new_state: InvoiceState,
) -> Result<usize>;
fn get_all(&self) -> Result<Vec<HoldInvoice>>;
fn get_paginated(&self, index_start: i64, limit: u64) -> Result<Vec<HoldInvoice>>;
fn get_by_payment_hash(&self, payment_hash: &[u8]) -> Result<Option<HoldInvoice>>;
Expand Down
25 changes: 20 additions & 5 deletions src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,25 @@ mod test {
impl InvoiceHelper for InvoiceHelper {
fn insert(&self, invoice: &InvoiceInsertable) -> Result<usize>;
fn insert_htlc(&self, htlc: &HtlcInsertable) -> Result<usize>;
fn set_invoice_state(&self, id: i64, state: InvoiceState) -> Result<usize>;
fn set_invoice_state(
&self,
id: i64,
state: InvoiceState,
new_state: InvoiceState,
) -> Result<usize>;
fn set_invoice_preimage(&self, id: i64, preimage: &[u8]) -> Result<usize>;
fn set_htlc_state_by_id(&self, htlc_id: i64, state: InvoiceState) -> Result<usize>;
fn set_htlc_states_by_invoice(&self, invoice_id: i64, state: InvoiceState) -> Result<usize>;
fn set_htlc_state_by_id(
&self,
htlc_id: i64,
state: InvoiceState,
new_state: InvoiceState,
) -> Result<usize>;
fn set_htlc_states_by_invoice(
&self,
invoice_id: i64,
state: InvoiceState,
new_state: InvoiceState,
) -> Result<usize>;
fn get_all(&self) -> Result<Vec<HoldInvoice>>;
fn get_paginated(&self, index_start: i64, limit: u64) -> Result<Vec<HoldInvoice>>;
fn get_by_payment_hash(&self, payment_hash: &[u8]) -> Result<Option<HoldInvoice>>;
Expand Down Expand Up @@ -548,10 +563,10 @@ mod test {
});
helper_settler
.expect_set_htlc_states_by_invoice()
.returning(|_, _| Ok(0));
.returning(|_, _, _| Ok(0));
helper_settler
.expect_set_invoice_state()
.returning(|_, _| Ok(0));
.returning(|_, _, _| Ok(0));
helper_settler
.expect_set_invoice_preimage()
.returning(|_, _| Ok(0));
Expand Down
Loading

0 comments on commit 32f4c4d

Please sign in to comment.