Skip to content

Commit

Permalink
[feat] Check payment tag (lnbits#2522)
Browse files Browse the repository at this point in the history
* feat: check if the payment is made for an extension that the user disabed
  • Loading branch information
motorina0 authored May 24, 2024
1 parent 93965bc commit 7c68a02
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 23 deletions.
5 changes: 5 additions & 0 deletions lnbits/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,8 @@ class BalanceDelta(BaseModel):
@property
def delta_msats(self):
return self.node_balance_msats - self.lnbits_balance_msats


class SimpleStatus(BaseModel):
success: bool
message: str
41 changes: 28 additions & 13 deletions lnbits/core/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@

from lnbits.core.db import db
from lnbits.db import Connection
from lnbits.decorators import WalletTypeInfo, require_admin_key
from lnbits.decorators import (
WalletTypeInfo,
check_user_extension_access,
require_admin_key,
)
from lnbits.helpers import url_for
from lnbits.lnurl import LnurlErrorResponse
from lnbits.lnurl import decode as decode_lnurl
Expand Down Expand Up @@ -300,18 +304,13 @@ class PaymentKwargs(TypedDict):
# do the balance check
wallet = await get_wallet(wallet_id, conn=conn)
assert wallet, "Wallet for balancecheck could not be fetched"
if wallet.balance_msat < 0:
logger.debug("balance is too low, deleting temporary payment")
if (
not internal_checking_id
and wallet.balance_msat > -fee_reserve_total_msat
):
raise PaymentError(
f"You must reserve at least ({round(fee_reserve_total_msat/1000)}"
" sat) to cover potential routing fees.",
status="failed",
)
raise PaymentError("Insufficient balance.", status="failed")
_check_wallet_balance(wallet, fee_reserve_total_msat, internal_checking_id)

if extra and "tag" in extra:
# check if the payment is made for an extension that the user disabled
status = await check_user_extension_access(wallet.user, extra["tag"])
if not status.success:
raise PaymentError(status.message)

if internal_checking_id:
service_fee_msat = service_fee(invoice.amount_msat, internal=True)
Expand Down Expand Up @@ -402,6 +401,22 @@ class PaymentKwargs(TypedDict):
return invoice.payment_hash


def _check_wallet_balance(
wallet: Wallet,
fee_reserve_total_msat: int,
internal_checking_id: Optional[str] = None,
):
if wallet.balance_msat < 0:
logger.debug("balance is too low, deleting temporary payment")
if not internal_checking_id and wallet.balance_msat > -fee_reserve_total_msat:
raise PaymentError(
f"You must reserve at least ({round(fee_reserve_total_msat/1000)}"
" sat) to cover potential routing fees.",
status="failed",
)
raise PaymentError("Insufficient balance.", status="failed")


async def check_wallet_limits(wallet_id, conn, amount_msat):
await check_time_limit_between_transactions(conn, wallet_id)
await check_wallet_daily_withdraw_limit(conn, wallet_id, amount_msat)
Expand Down
29 changes: 19 additions & 10 deletions lnbits/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
get_user_active_extensions_ids,
get_wallet_for_key,
)
from lnbits.core.models import KeyType, User, WalletTypeInfo
from lnbits.core.models import KeyType, SimpleStatus, User, WalletTypeInfo
from lnbits.db import Filter, Filters, TFilterModel
from lnbits.settings import AuthMethods, settings

Expand Down Expand Up @@ -210,27 +210,36 @@ def dependency(
return dependency


async def _check_user_extension_access(user_id: str, current_path: str):
async def check_user_extension_access(user_id: str, ext_id: str) -> SimpleStatus:
"""
Check if the user has access to a particular extension.
Raises HTTP Forbidden if the user is not allowed.
"""
path = current_path.split("/")
ext_id = path[3] if path[1] == "upgrades" else path[1]
if settings.is_admin_extension(ext_id) and not settings.is_admin_user(user_id):
raise HTTPException(
HTTPStatus.FORBIDDEN,
f"User not authorized for extension '{ext_id}'.",
return SimpleStatus(
success=False, message=f"User not authorized for extension '{ext_id}'."
)

if settings.is_extension_id(ext_id):
ext_ids = await get_user_active_extensions_ids(user_id)
if ext_id not in ext_ids:
raise HTTPException(
HTTPStatus.FORBIDDEN,
f"User extension '{ext_id}' not enabled.",
return SimpleStatus(
success=False, message=f"User extension '{ext_id}' not enabled."
)

return SimpleStatus(success=True, message="OK")


async def _check_user_extension_access(user_id: str, current_path: str):
path = current_path.split("/")
ext_id = path[3] if path[1] == "upgrades" else path[1]
status = await check_user_extension_access(user_id, ext_id)
if not status.success:
raise HTTPException(
HTTPStatus.FORBIDDEN,
status.message,
)


async def _get_account_from_token(access_token):
try:
Expand Down

0 comments on commit 7c68a02

Please sign in to comment.