Skip to content

Commit

Permalink
feat: add coincurve instead of secp256k1 (#46)
Browse files Browse the repository at this point in the history
* feat: add coincurve instead of secp256k1

- changed secp256k1 to coincurve
- add testcase for verifying pubkey
lightning/bolts#1184
- added keep_payee flag for encoding
- added tests for signature model
- added test for invalid signature
  • Loading branch information
dni authored Jul 19, 2024
1 parent ef22967 commit 6dccbd2
Show file tree
Hide file tree
Showing 9 changed files with 566 additions and 469 deletions.
5 changes: 4 additions & 1 deletion bolt11/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" bolt11 CLI """
"""bolt11 CLI"""

import json
import sys
Expand Down Expand Up @@ -52,10 +52,12 @@ def decode(bolt11, ignore_exceptions, strict):
@click.argument("private_key", type=str, default=None, required=False)
@click.argument("ignore_exceptions", type=bool, default=True)
@click.argument("strict", type=bool, default=False)
@click.argument("keep_payee", type=bool, default=False)
def encode(
json_string,
ignore_exceptions: bool = True,
strict: bool = False,
keep_payee: bool = False,
private_key: Optional[str] = None,
):
"""
Expand Down Expand Up @@ -92,6 +94,7 @@ def encode(
private_key,
ignore_exceptions=ignore_exceptions,
strict=strict,
keep_payee=keep_payee,
)
click.echo(encoded)
except Bolt11Exception as exc:
Expand Down
16 changes: 11 additions & 5 deletions bolt11/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ def decode(
timestamp = data_part.read(35).uint

tags = Tags()
payee = None

while data_part.pos != data_part.len:
tag, tagdata, data_part = _pull_tagged(data_part)
data_length = int(len(tagdata or []) / 5)

# MUST skip over unknown fields, OR an f field with unknown version, OR p, h,
# s or n fields that do NOT have data_lengths of 52, 52, 52 or 53, respectively.

if (
tag == TagChar.payment_hash.value
and data_length == 52
Expand Down Expand Up @@ -93,9 +95,10 @@ def decode(
and data_length == 53
and not tags.has(TagChar.payee)
):
payee = trim_to_bytes(tagdata).hex()
tags.add(
TagChar.payee,
trim_to_bytes(tagdata).hex(),
payee,
)
elif (
tag == TagChar.description.value
Expand Down Expand Up @@ -133,19 +136,22 @@ def decode(
elif tag == TagChar.route_hint.value:
tags.add(TagChar.route_hint, RouteHint.from_bitstring(tagdata))

else:
# skip unknown fields
pass

signature = Signature(
signature_data=signature_data,
signing_data=hrp.encode() + data_part.tobytes(),
signing_data=data_part.tobytes(),
hrp=hrp,
)

# A reader MUST check that the `signature` is valid (see the `n` tagged field
# specified below). A reader MUST use the `n` field to validate the signature
# instead of performing signature recovery if a valid `n` field is provided.
payee = tags.get(TagChar.payee)
if payee:
# TODO: research why no test runs this?
try:
signature.verify(payee.data)
signature.verify(payee)
except Exception as exc:
raise Bolt11SignatureVerifyException() from exc
else:
Expand Down
11 changes: 6 additions & 5 deletions bolt11/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def encode(
private_key: Optional[str] = None,
ignore_exceptions: bool = False,
strict: bool = False,
keep_payee: bool = False,
) -> str:
try:
if invoice.description_hash:
Expand All @@ -75,10 +76,8 @@ def encode(
tags += _tagged_bytes(tag.bech32, bytes.fromhex(tag.data))
elif tag.char == TagChar.metadata:
tags += _tagged_bytes(tag.bech32, bytes.fromhex(tag.data))
# TODO: why uncommented?
# payee is not needed, needs more research
# elif tag.char == TagChar.payee:
# tags += _tagged_bytes(tag.bech32, bytes.fromhex(tag.data))
elif tag.char == TagChar.payee and keep_payee:
tags += _tagged_bytes(tag.bech32, bytes.fromhex(tag.data))
elif tag.char == TagChar.features:
tags += _tagged_bytes(tag.bech32, tag.data.data)
elif tag.char == TagChar.fallback:
Expand All @@ -94,7 +93,9 @@ def encode(
data_part = timestamp + tags

if private_key:
invoice.signature = Signature.from_private_key(private_key, hrp, data_part)
invoice.signature = Signature.from_private_key(
hrp=hrp, private_key=private_key, signing_data=data_part.tobytes()
)

if not invoice.signature:
raise Bolt11NoSignatureException()
Expand Down
60 changes: 38 additions & 22 deletions bolt11/models/signature.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,60 @@
from dataclasses import dataclass
from hashlib import sha256
from typing import Optional

from bitstring import Bits
from ecdsa import SECP256k1, VerifyingKey
from ecdsa.util import sigdecode_string
from secp256k1 import PrivateKey
from coincurve import PrivateKey, PublicKey, verify_signature
from coincurve.ecdsa import cdata_to_der, deserialize_recoverable, recoverable_convert


def message(hrp: str, signing_data: bytes) -> bytes:
return bytearray([ord(c) for c in hrp]) + signing_data


@dataclass
class Signature:
"""An invoice signature."""

hrp: str
signing_data: bytes
signature_data: Optional[bytes] = None
signature_data: bytes

@classmethod
def from_signature_data(
cls, hrp: str, signature_data: bytes, signing_data: bytes
) -> "Signature":
return cls(hrp=hrp, signature_data=signature_data, signing_data=signing_data)

@classmethod
def from_private_key(
cls, private_key: str, hrp: str, signing_data: Bits
cls, hrp: str, private_key: str, signing_data: bytes
) -> "Signature":
key = PrivateKey(bytes.fromhex(private_key))
sig = key.ecdsa_sign_recoverable(
bytearray([ord(c) for c in hrp]) + signing_data.tobytes()
)
sig, recid = key.ecdsa_recoverable_serialize(sig)
signature_data = bytes(sig) + bytes([recid])
return cls(signing_data=signing_data.tobytes(), signature_data=signature_data)
key = PrivateKey.from_hex(private_key)
signature_data = key.sign_recoverable(message(hrp, signing_data))
return cls(hrp=hrp, signing_data=signing_data, signature_data=signature_data)

def verify(self, payee: str) -> bool:
key = VerifyingKey.from_string(bytes.fromhex(payee), curve=SECP256k1)
return key.verify(
self.sig, self.signing_data, sha256, sigdecode=sigdecode_string
)
if not self.signature_data:
raise ValueError("No signature data")
if not self.signing_data:
raise ValueError("No signing data")
sig = deserialize_recoverable(self.signature_data)
sig = recoverable_convert(sig)
sig = cdata_to_der(sig)
if not verify_signature(
sig, message(self.hrp, self.signing_data), bytes.fromhex(payee)
):
raise ValueError("Invalid signature")
return True

def recover_public_key(self) -> str:
keys = VerifyingKey.from_public_key_recovery(
self.sig, self.signing_data, SECP256k1, sha256
if not self.signature_data:
raise ValueError("No signature data")
if not self.signing_data:
raise ValueError("No signing data")

key = PublicKey.from_signature_and_message(
self.signature_data, message(self.hrp, self.signing_data)
)
key = keys[self.recovery_flag]
return key.to_string("compressed").hex()
return key.format(compressed=True).hex()

@property
def r(self) -> str:
Expand Down
Loading

0 comments on commit 6dccbd2

Please sign in to comment.