Skip to content

Commit

Permalink
pytest: add spendable tests for askrene.
Browse files Browse the repository at this point in the history
Make sure we're not exceeding the spendable amount of a local channel.

Signed-off-by: Rusty Russell <[email protected]>
  • Loading branch information
rustyrussell authored and ShahanaFarooqui committed Aug 19, 2024
1 parent 7fb7234 commit b99fd02
Showing 1 changed file with 74 additions and 1 deletion.
75 changes: 74 additions & 1 deletion tests/test_askrene.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pyln.client import RpcError
from utils import (
only_one, first_scid, GenChannel, generate_gossip_store,
TEST_NETWORK
TEST_NETWORK, sync_blockheight, wait_for
)
import os
import pytest
Expand Down Expand Up @@ -415,3 +415,76 @@ def test_getroutes_auto_localchans(node_factory):
paths=[[{'short_channel_id': scid12, 'amount_msat': 102010, 'delay': 99 + 6 + 6},
{'short_channel_id': '0x1x0', 'amount_msat': 102010, 'delay': 99 + 6 + 6},
{'short_channel_id': '1x2x1', 'amount_msat': 101000, 'delay': 99 + 6}]])


def test_fees_dont_exceed_constraints(node_factory):
l1 = node_factory.get_node(start=False)

msat = 100000000
max_msat = int(msat * 0.45)
# 0 has to use two paths (1 and 2) to reach 3. But we tell it 0->1 has limited capacity.
gsfile, nodemap = generate_gossip_store([GenChannel(0, 1, capacity_sats=msat // 1000, forward=GenChannel.Half(propfee=10000)),
GenChannel(0, 2, capacity_sats=msat // 1000, forward=GenChannel.Half(propfee=10000)),
GenChannel(1, 3, capacity_sats=msat // 1000, forward=GenChannel.Half(propfee=10000)),
GenChannel(2, 3, capacity_sats=msat // 1000, forward=GenChannel.Half(propfee=10000))])

# Set up l1 with this as the gossip_store
shutil.copy(gsfile.name, os.path.join(l1.daemon.lightning_dir, TEST_NETWORK, 'gossip_store'))
l1.start()

chan = only_one([c for c in l1.rpc.listchannels(source=nodemap[0])['channels'] if c['destination'] == nodemap[1]])
l1.rpc.askrene_inform_channel(layer='test_layers',
short_channel_id=chan['short_channel_id'],
direction=chan['direction'],
maximum_msat=max_msat)

routes = l1.rpc.getroutes(source=nodemap[0],
destination=nodemap[3],
amount_msat=msat,
layers=['test_layers'],
maxfee_msat=msat,
final_cltv=99)['routes']
assert len(routes) == 2
for hop in routes[0]['path'] + routes[1]['path']:
if hop['short_channel_id'] == chan['short_channel_id']:
amount = hop['amount_msat']
assert amount <= max_msat


def test_mpp_pay2(node_factory, bitcoind):
l1, l2, l3 = node_factory.get_nodes(3)
l1.fundwallet(10_000_000)
l2.fundwallet(10_000_000)
l1.rpc.connect(l2.info['id'], 'localhost', port=l2.port)
l2.rpc.connect(l3.info['id'], 'localhost', port=l3.port)

capacities = (100_000, 100_000, 200_000, 300_000, 400_000)
for capacity in capacities:
l1.rpc.fundchannel(l2.info["id"], capacity, mindepth=1)
l2.rpc.fundchannel(l3.info["id"], capacity, mindepth=1)

bitcoind.generate_block(1, wait_for_mempool=2)
sync_blockheight(bitcoind, [l1, l2])

bitcoind.generate_block(5)
wait_for(lambda: len(l1.rpc.listchannels()["channels"]) == 2 * 2 * len(capacities))

routes = l1.rpc.getroutes(
source=l1.info["id"],
destination=l3.info["id"],
amount_msat=800_000_000,
layers=["auto.localchans", "auto.sourcefree"],
maxfee_msat=50_000_000,
final_cltv=10,
)

# Don't exceed spendable_msat
maxes = {}
for chan in l1.rpc.listpeerchannels()['channels']:
maxes["{}/{}".format(chan['short_channel_id'], chan['direction'])] = chan['spendable_msat']

for r in routes['routes']:
for p in r['path']:
scidd = "{}/{}".format(p['short_channel_id'], p['direction'])
if scidd in maxes:
assert p['amount_msat'] <= maxes[scidd]

0 comments on commit b99fd02

Please sign in to comment.