onion messages wip #9009

Closed
wants to merge 14 commits into from
70 changes: 70 additions & 0 deletions electrum/bolt12.py
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2023 The Electrum developers
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import asyncio
import io

from . import lnmsg
from .segwit_addr import bech32_decode, DecodedBech32, convertbits


def is_offer(data: str):
d = bech32_decode(data, ignore_long_length=True, with_checksum=False)
if d == DecodedBech32(None, None, None):
return False
return d.hrp == 'lno'


def decode_offer(data):
d = bech32_decode(data, ignore_long_length=True, with_checksum=False)
d = bytes(convertbits(d.data, 5, 8))
# strip trailing zero padding left over from the 5->8 bit conversion,
# otherwise the TLV reader chokes on it
while d and d[-1] == 0:
    d = d[:-1]
f = io.BytesIO(d)
lns = lnmsg.LNSerializer()
return lns.read_tlv_stream(fd=f, tlv_stream_name='offer')


def decode_invoice_request(data):
# strip trailing zero padding, otherwise the TLV reader chokes on it
while data and data[-1] == 0:
    data = data[:-1]
f = io.BytesIO(data)
lns = lnmsg.LNSerializer()
return lns.read_tlv_stream(fd=f, tlv_stream_name='invoice_request')


def decode_invoice(data):
# strip trailing zero padding, otherwise the TLV reader chokes on it
while data and data[-1] == 0:
    data = data[:-1]
f = io.BytesIO(data)
lns = lnmsg.LNSerializer()
return lns.read_tlv_stream(fd=f, tlv_stream_name='invoice')


async def request_invoice(bolt12_offer):
    # TODO: WIP placeholder; sleep without blocking the event loop
    await asyncio.sleep(5)
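
# Illustrative usage of the helpers above (not part of this PR): the offer
# string would be a bech32 "lno..." string obtained out of band.
def _example_decode(offer_str: str):
    if is_offer(offer_str):
        return decode_offer(offer_str)  # dict: TLV record name -> dict of parsed fields
    return None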
145 changes: 138 additions & 7 deletions electrum/lnmsg.py
@@ -147,6 +147,18 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U
type_len = 33
elif field_type == 'short_channel_id':
type_len = 8
elif field_type == 'sciddir_or_pubkey':
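# sciddir_or_pubkey is variable-length on the wire: a prefix byte of 0 or 1
# means a direction byte followed by an 8-byte short_channel_id (9 bytes in
# total), while 2 or 3 is the parity byte of a 33-byte compressed pubkey.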
buf = fd.read(1)
if len(buf) != 1:
    raise UnexpectedEndOfStream()
if buf[0] in [0, 1]:
type_len = 9
elif buf[0] in [2, 3]:
type_len = 33
else:
raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
buf += fd.read(type_len - 1)
if len(buf) != type_len:
raise UnexpectedEndOfStream()
return buf

if count == "...":
total_len = -1 # read all
@@ -225,6 +237,14 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
type_len = 33
elif field_type == 'short_channel_id':
type_len = 8
elif field_type == 'sciddir_or_pubkey':
assert isinstance(value, bytes)
if value[0] in [0, 1]:
type_len = 9 # short_channel_id
elif value[0] in [2, 3]:
type_len = 33 # point
else:
raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
total_len = -1
if count != "...":
if type_len is None:
@@ -299,6 +319,8 @@ def __init__(self, *, for_onion_wire: bool = False):
self.in_tlv_stream_get_record_type_from_name = {} # type: Dict[str, Dict[str, int]]
self.in_tlv_stream_get_record_name_from_type = {} # type: Dict[str, Dict[int, str]]

self.subtypes = {} # type: Dict[str, Dict[str, Sequence[str]]]

if for_onion_wire:
path = os.path.join(os.path.dirname(__file__), "lnwire", "onion_wire.csv")
else:
@@ -348,9 +370,107 @@ def __init__(self, *, for_onion_wire: bool = False):
assert tlv_stream_name == row[1]
assert tlv_record_name == row[2]
self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
elif row[0] == "subtype":
# subtype,<subtypename>
subtypename = row[1]
assert subtypename not in self.subtypes, f"duplicate declaration of subtype {subtypename}"
self.subtypes[subtypename] = {}
elif row[0] == "subtypedata":
# subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
subtypename = row[1]
fieldname = row[2]
assert subtypename in self.subtypes, f"subtypedata definition for subtype {subtypename} declared before subtype"
assert fieldname not in self.subtypes[subtypename], f"duplicate field definition for {fieldname} for subtype {subtypename}"
self.subtypes[subtypename][fieldname] = tuple(row)
else:
pass # TODO

def _write_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str],
value: Union[List[Dict[str, Any]], Dict[str, Any]]) -> None:
assert fd
assert field_type in self.subtypes, f"unknown subtype {field_type}"

if isinstance(count, int):
assert count >= 0, f"{count!r} must be non-neg int"
elif count == "...":
pass
else:
raise Exception(f"unexpected field count: {count!r}")
if count == 0:
return

if count == 1:
assert isinstance(value, dict) or isinstance(value, list)
values = [value] if isinstance(value, dict) else value
else:
assert isinstance(value, list)
values = value

if count == '...':
count = len(values)
else:
assert count == len(values)
if count == 0:
return

for record in values:
for row in self.subtypes[field_type].values():
# subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
subtype_field_name = row[2]
subtype_field_type = row[3]
subtype_field_count_str = row[4]

subtype_field_count = _resolve_field_count(subtype_field_count_str,
vars_dict=record,
allow_any=True)
if subtype_field_type in self.subtypes:
self._write_complex_field(fd=fd,
field_type=subtype_field_type,
count=subtype_field_count,
value=record[subtype_field_name])
else:
_write_field(fd=fd,
field_type=subtype_field_type,
count=subtype_field_count,
value=record[subtype_field_name])

def _read_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str])\
-> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]:
assert fd
if isinstance(count, int):
assert count >= 0, f"{count!r} must be non-neg int"
elif count == "...":
raise Exception(f"cannot read subtype {field_type!r} without an explicit count")
else:
raise Exception(f"unexpected field count: {count!r}")
if count == 0:
return b""

parsedlist = [{} for x in range(count)]

for parsed in parsedlist:
for row in self.subtypes[field_type].values():
# subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
subtype_field_name = row[2]
subtype_field_type = row[3]
subtype_field_count_str = row[4]

subtype_field_count = _resolve_field_count(subtype_field_count_str,
vars_dict=parsed,
allow_any=True)

if subtype_field_type in self.subtypes:
parsed[subtype_field_name] = self._read_complex_field(fd=fd,
field_type=subtype_field_type,
count=subtype_field_count)
else:
parsed[subtype_field_name] = _read_field(fd=fd,
field_type=subtype_field_type,
count=subtype_field_count)

return parsedlist[0] if len(parsedlist) == 1 else parsedlist

def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
for tlv_record_type, scheme in scheme_map.items(): # note: tlv_record_type is monotonically increasing
@@ -372,10 +492,16 @@ def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) ->
vars_dict=kwargs[tlv_record_name],
allow_any=True)
field_value = kwargs[tlv_record_name][field_name]
_write_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count,
value=field_value)
if field_type in self.subtypes:
self._write_complex_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count,
value=field_value)
else:
_write_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count,
value=field_value)
else:
raise Exception(f"unexpected row in scheme: {row!r}")
_write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())
@@ -417,9 +543,14 @@ def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str,
vars_dict=parsed[tlv_record_name],
allow_any=True)
#print(f">> count={field_count}. parsed={parsed}")
parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count)
if field_type in self.subtypes:
parsed[tlv_record_name][field_name] = self._read_complex_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count)
else:
parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count)
else:
raise Exception(f"unexpected row in scheme: {row!r}")
if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
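
With the subtype machinery above, a TLV field whose declared type is a registered subtype now round-trips as a nested dict (or a list of dicts) rather than opaque bytes. A minimal reading sketch follows; the helper name is hypothetical, while 'onionmsg_tlv' is the stream this PR uses for onion messages:

import io
from electrum.lnmsg import OnionWireSerializer

def parse_onion_message_payload(raw_tlv_bytes: bytes) -> dict:
    # returns {tlv_record_name: {field_name: value or nested subtype dict(s)}}
    with io.BytesIO(raw_tlv_bytes) as fd:
        return OnionWireSerializer.read_tlv_stream(fd=fd, tlv_stream_name='onionmsg_tlv')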
80 changes: 75 additions & 5 deletions electrum/lnonion.py
@@ -25,11 +25,13 @@

import io
import hashlib
from copy import deepcopy
from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union
from enum import IntEnum

from . import ecc
from .crypto import sha256, hmac_oneshot, chacha20_encrypt
from .crypto import sha256, hmac_oneshot, chacha20_encrypt, chacha20_poly1305_encrypt
from .ecc import ECPubkey
from .util import profiler, xor_bytes, bfh
from .lnutil import (get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH,
NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, OnionFailureCodeMetaFlag)
@@ -52,11 +54,15 @@ class InvalidOnionPubkey(Exception): pass

class OnionHopsDataSingle: # called HopData in lnd

def __init__(self, *, payload: dict = None):
def __init__(self, *, payload: dict = None, tlv_stream_name: str = 'payload', blind_fields: dict = None):
if payload is None:
payload = {}
self.payload = payload
self.hmac = None
self.tlv_stream_name = tlv_stream_name
if blind_fields is None:
blind_fields = {}
self.blind_fields = blind_fields
self._raw_bytes_payload = None # used in unit tests

def to_bytes(self) -> bytes:
Expand All @@ -68,7 +74,7 @@ def to_bytes(self) -> bytes:
# adding TLV payload. note: legacy hop data format no longer supported.
payload_fd = io.BytesIO()
OnionWireSerializer.write_tlv_stream(fd=payload_fd,
tlv_stream_name="payload",
tlv_stream_name=self.tlv_stream_name,
**self.payload)
payload_bytes = payload_fd.getvalue()
with io.BytesIO() as fd:
@@ -145,7 +151,7 @@ def from_bytes(cls, b: bytes):


def get_bolt04_onion_key(key_type: bytes, secret: bytes) -> bytes:
if key_type not in (b'rho', b'mu', b'um', b'ammag', b'pad'):
if key_type not in (b'rho', b'mu', b'um', b'ammag', b'pad', b'blinded_node_id'):
raise Exception('invalid key_type {}'.format(key_type))
key = hmac_oneshot(key_type, msg=secret, digest=hashlib.sha256)
return key
@@ -168,12 +174,42 @@ def get_shared_secrets_along_route(payment_path_pubkeys: Sequence[bytes],
return hop_shared_secrets


def get_shared_secrets_along_route2(payment_path_pubkeys_plus: Sequence[Union[bytes, Tuple[bytes, bytes]]],
session_key: bytes) -> Tuple[Sequence[bytes], Sequence[bytes]]:
num_hops = len(payment_path_pubkeys_plus)
hop_shared_secrets = num_hops * [b'']
hop_blinded_node_ids = num_hops * [b'']
ephemeral_key = session_key
payment_path_pubkeys = list(payment_path_pubkeys_plus)  # shallow copy suffices; elements are only rebound
# compute shared key for each hop
for i in range(0, num_hops):
if isinstance(payment_path_pubkeys[i], tuple):
ephemeral_key = payment_path_pubkeys[i][1]
payment_path_pubkeys[i] = payment_path_pubkeys[i][0]
hop_shared_secrets[i] = get_ecdh(ephemeral_key, payment_path_pubkeys[i])

# blinded node id
# B(i) = HMAC256("blinded_node_id", ss(i)) * N(i)
ss_bni_hmac = get_bolt04_onion_key(b'blinded_node_id', hop_shared_secrets[i])
ss_bni_hmac_int = int.from_bytes(ss_bni_hmac, byteorder="big")
blinded_node_id = ECPubkey(payment_path_pubkeys[i]) * ss_bni_hmac_int
hop_blinded_node_ids[i] = blinded_node_id.get_public_key_bytes()

ephemeral_pubkey = ecc.ECPrivkey(ephemeral_key).get_public_key_bytes()
blinding_factor = sha256(ephemeral_pubkey + hop_shared_secrets[i])
blinding_factor_int = int.from_bytes(blinding_factor, byteorder="big")
ephemeral_key_int = int.from_bytes(ephemeral_key, byteorder="big")
ephemeral_key_int = ephemeral_key_int * blinding_factor_int % ecc.CURVE_ORDER
ephemeral_key = ephemeral_key_int.to_bytes(32, byteorder="big")
return hop_shared_secrets, hop_blinded_node_ids
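
# Usage sketch (comment only; throwaway keys): each path entry is either a
# node pubkey, or a (pubkey, ephemeral_key_override) tuple as handled above.
#
#   import os
#   alice = ecc.ECPrivkey(os.urandom(32)).get_public_key_bytes()
#   bob = ecc.ECPrivkey(os.urandom(32)).get_public_key_bytes()
#   secrets, blinded_ids = get_shared_secrets_along_route2([alice, bob], os.urandom(32))
#   assert len(secrets) == len(blinded_ids) == 2  # blinded_ids are the B(i) points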


def new_onion_packet(
payment_path_pubkeys: Sequence[bytes],
session_key: bytes,
hops_data: Sequence[OnionHopsDataSingle],
*,
associated_data: bytes,
associated_data: bytes = b'',
trampoline: bool = False,
) -> OnionPacket:
num_hops = len(payment_path_pubkeys)
@@ -210,6 +246,40 @@ def new_onion_packet(
hmac=next_hmac)


def new_onion_packet2(
payment_path_pubkeys: Sequence[Union[bytes, tuple]],
session_key: bytes,
hops_data: Sequence[OnionHopsDataSingle],
*,
associated_data: bytes = b'',
trampoline: bool = False,
) -> OnionPacket:
num_hops = len(payment_path_pubkeys)
assert num_hops == len(hops_data)
hop_shared_secrets, blinded_node_ids = get_shared_secrets_along_route2(payment_path_pubkeys, session_key)
# compute routing info and MAC for each hop
for i in range(num_hops):
rho_key = get_bolt04_onion_key(b'rho', hop_shared_secrets[i])
if hops_data[i].tlv_stream_name == 'onionmsg_tlv':  # route-blinded (onion message) hop
encrypted_data_tlv_fd = io.BytesIO()
OnionWireSerializer.write_tlv_stream(
fd=encrypted_data_tlv_fd,
tlv_stream_name='encrypted_data_tlv',
**hops_data[i].blind_fields)
encrypted_data_tlv_bytes = encrypted_data_tlv_fd.getvalue()
encrypted_recipient_data = chacha20_poly1305_encrypt(key=rho_key, nonce=bytes(12), data=encrypted_data_tlv_bytes)
payload = hops_data[i].payload
payload['encrypted_recipient_data'] = {'encrypted_recipient_data': encrypted_recipient_data}

return new_onion_packet(
payment_path_pubkeys=blinded_node_ids,
session_key=session_key,
hops_data=hops_data,
associated_data=associated_data,
trampoline=trampoline,
)
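
# Sketch (not part of this PR): how a caller might build a route-blinded onion
# message with new_onion_packet2. The blind_fields record/field names
# ('next_node_id'/'node_id', 'path_id'/'data') are assumed to follow BOLT 4's
# encrypted_data_tlv scheme in onion_wire.csv; keys here are placeholders.
def _example_onion_message_packet(first_hop_pubkey: bytes, dest_pubkey: bytes) -> OnionPacket:
    import os  # local import, only for the throwaway path_id / session key
    hops_data = [
        # intermediate hop: tell it which node to forward the message to
        OnionHopsDataSingle(tlv_stream_name='onionmsg_tlv',
                            blind_fields={'next_node_id': {'node_id': dest_pubkey}}),
        # final hop: an opaque path_id the recipient can use to correlate replies
        OnionHopsDataSingle(tlv_stream_name='onionmsg_tlv',
                            blind_fields={'path_id': {'data': os.urandom(32)}}),
    ]
    session_key = os.urandom(32)  # must be a valid secp256k1 scalar (assumed here)
    return new_onion_packet2([first_hop_pubkey, dest_pubkey], session_key, hops_data)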


def calc_hops_data_for_payment(
route: 'LNPaymentRoute',
amount_msat: int, # that final recipient receives