diff --git a/protos/hold.proto b/protos/hold.proto index 4a80d07..fb8b700 100644 --- a/protos/hold.proto +++ b/protos/hold.proto @@ -111,7 +111,10 @@ message TrackResponse { InvoiceState state = 1; } -message TrackAllRequest {} +message TrackAllRequest { + repeated bytes payment_hashes = 1; +} + message TrackAllResponse { bytes payment_hash = 1; string bolt11 = 2; diff --git a/src/grpc/service.rs b/src/grpc/service.rs index a9276d3..a26bf86 100644 --- a/src/grpc/service.rs +++ b/src/grpc/service.rs @@ -12,7 +12,7 @@ use crate::grpc::service::hold::{ use crate::grpc::transformers::{transform_invoice_state, transform_route_hints}; use crate::settler::Settler; use bitcoin::hashes::{sha256, Hash}; -use log::{debug, error}; +use log::{debug, error, warn}; use std::pin::Pin; use tokio::sync::mpsc; use tonic::codegen::tokio_stream::wrappers::ReceiverStream; @@ -270,13 +270,69 @@ where async fn track_all( &self, - _: Request, + request: Request, ) -> Result, Status> { + let params = request.into_inner(); + let (tx, rx) = mpsc::channel(128); + let invoice_helper = self.invoice_helper.clone(); let mut state_rx = self.settler.state_rx(); tokio::spawn(async move { + for hash in params.payment_hashes { + let invoice = match invoice_helper.get_by_payment_hash(&hash) { + Ok(invoice) => match invoice { + Some(invoice) => invoice, + None => { + warn!( + "Could not find invoice with payment hash: {}", + hex::encode(&hash) + ); + continue; + } + }, + Err(err) => { + let err = format!( + "Could not get invoice with payment hash {}: {}", + hex::encode(&hash), + err + ); + error!("{}", err); + let _ = tx.send(Err(Status::new(Code::Internal, err))).await; + return; + } + }; + + let state = transform_invoice_state( + match InvoiceState::try_from(invoice.invoice.state.as_str()) { + Ok(state) => state, + Err(err) => { + let err = format!( + "Could not parse state of invoice {}: {}", + hex::encode(&hash), + err + ); + error!("{}", err); + let _ = tx.send(Err(Status::new(Code::Internal, err))).await; + return; + } + }, + ); + + if let Err(err) = tx + .send(Ok(TrackAllResponse { + state, + bolt11: invoice.invoice.bolt11, + payment_hash: invoice.invoice.payment_hash, + })) + .await + { + error!("Could not send invoice state: {}", err); + return; + }; + } + loop { match state_rx.recv().await { Ok(update) => { diff --git a/tests-regtest/hold/regtest_grpc.py b/tests-regtest/hold/regtest_grpc.py index 128b5ea..e9fcf78 100644 --- a/tests-regtest/hold/regtest_grpc.py +++ b/tests-regtest/hold/regtest_grpc.py @@ -377,3 +377,47 @@ def track_states() -> list[tuple[bytes, str, str]]: (payment_hash_settled, invoice_settled.bolt11, InvoiceState.ACCEPTED), (payment_hash_settled, invoice_settled.bolt11, InvoiceState.PAID), ] + + def test_track_all_existing(self, cl: HoldStub) -> None: + expected_events = 3 + + (_, payment_hash_not_found) = new_preimage_bytes() + (preimage_settled, payment_hash_settled) = new_preimage_bytes() + + def track_states() -> list[tuple[bytes, str, str]]: + evs = [] + + sub = cl.TrackAll( + TrackAllRequest( + payment_hashes=[payment_hash_not_found, payment_hash_settled] + ) + ) + for ev in sub: + evs.append((ev.payment_hash, ev.bolt11, ev.state)) + if len(evs) == expected_events: + sub.cancel() + break + + return evs + + with concurrent.futures.ThreadPoolExecutor() as pool: + invoice_settled: InvoiceResponse = cl.Invoice( + InvoiceRequest(payment_hash=payment_hash_settled, amount_msat=1_000) + ) + + fut = pool.submit(track_states) + + pay = LndPay(1, invoice_settled.bolt11) + pay.start() + time.sleep(1) + + cl.Settle(SettleRequest(payment_preimage=preimage_settled)) + pay.join() + + res = fut.result() + assert len(res) == expected_events + assert res == [ + (payment_hash_settled, invoice_settled.bolt11, InvoiceState.UNPAID), + (payment_hash_settled, invoice_settled.bolt11, InvoiceState.ACCEPTED), + (payment_hash_settled, invoice_settled.bolt11, InvoiceState.PAID), + ]