diff --git a/.gitignore b/.gitignore index 916190a9..171f4955 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ nostr-wallet-connect nwc.db .breez .data +.idea frontend/dist frontend/node_modules @@ -21,6 +22,7 @@ build/bin # generated by platform-specific files frontend/src/utils/request.ts frontend/src/utils/openLink.ts +frontend/.yarn # generated by rust go bindings for local development glalby diff --git a/db/queries/get_budget_usage.go b/db/queries/get_budget_usage.go index 78abc3b1..20d72b76 100644 --- a/db/queries/get_budget_usage.go +++ b/db/queries/get_budget_usage.go @@ -42,3 +42,26 @@ func getStartOfBudget(budget_type string) time.Time { return time.Time{} } } + +func GetBudgetRenewsAt(budgetRenewal string) *uint64 { + budgetStart := getStartOfBudget(budgetRenewal) + switch budgetRenewal { + case constants.BUDGET_RENEWAL_DAILY: + renewal := uint64(budgetStart.AddDate(0, 0, 1).Unix()) + return &renewal + case constants.BUDGET_RENEWAL_WEEKLY: + renewal := uint64(budgetStart.AddDate(0, 0, 7).Unix()) + return &renewal + + case constants.BUDGET_RENEWAL_MONTHLY: + renewal := uint64(budgetStart.AddDate(0, 1, 0).Unix()) + return &renewal + + case constants.BUDGET_RENEWAL_YEARLY: + renewal := uint64(budgetStart.AddDate(1, 0, 0).Unix()) + return &renewal + + default: //"never" + return nil + } +} diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 2d4a5e01..35c1ee5c 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -21,6 +21,7 @@ export type BackendType = export type Nip47RequestMethod = | "get_info" | "get_balance" + | "get_budget" | "make_invoice" | "pay_invoice" | "pay_keysend" diff --git a/lnclient/breez/breez.go b/lnclient/breez/breez.go index 2cea7f2b..489673a6 100644 --- a/lnclient/breez/breez.go +++ b/lnclient/breez/breez.go @@ -478,7 +478,7 @@ func (bs *BreezService) DisconnectPeer(ctx context.Context, peerId string) error } func (bs *BreezService) GetSupportedNIP47Methods() []string { - return []string{"pay_invoice" /*"pay_keysend",*/, "get_balance", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice", "multi_pay_keysend", "sign_message"} + return []string{"pay_invoice" /*"pay_keysend",*/, "get_balance", "get_budget", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice", "multi_pay_keysend", "sign_message"} } func (bs *BreezService) GetSupportedNIP47NotificationTypes() []string { diff --git a/lnclient/cashu/cashu.go b/lnclient/cashu/cashu.go index 45e352dc..f7e41c23 100644 --- a/lnclient/cashu/cashu.go +++ b/lnclient/cashu/cashu.go @@ -358,7 +358,7 @@ func (cs *CashuService) checkInvoice(cashuInvoice *storage.Invoice) { } func (cs *CashuService) GetSupportedNIP47Methods() []string { - return []string{"pay_invoice", "get_balance", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice"} + return []string{"pay_invoice", "get_balance", "get_budget", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice"} } func (cs *CashuService) GetSupportedNIP47NotificationTypes() []string { diff --git a/lnclient/greenlight/greenlight.go b/lnclient/greenlight/greenlight.go index 432313a7..8b5f6cc3 100644 --- a/lnclient/greenlight/greenlight.go +++ b/lnclient/greenlight/greenlight.go @@ -674,7 +674,7 @@ func (gs *GreenlightService) DisconnectPeer(ctx context.Context, peerId string) } func (gs *GreenlightService) GetSupportedNIP47Methods() []string { - return []string{"pay_invoice" /*"pay_keysend",*/, "get_balance", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice", "multi_pay_keysend", "sign_message"} + return []string{"pay_invoice" /*"pay_keysend",*/, "get_balance", "get_budget", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice", "multi_pay_keysend", "sign_message"} } func (gs *GreenlightService) GetSupportedNIP47NotificationTypes() []string { diff --git a/lnclient/ldk/ldk.go b/lnclient/ldk/ldk.go index 3886dbcf..720c6bd5 100644 --- a/lnclient/ldk/ldk.go +++ b/lnclient/ldk/ldk.go @@ -1631,7 +1631,7 @@ func (ls *LDKService) UpdateLastWalletSyncRequest() { } func (ls *LDKService) GetSupportedNIP47Methods() []string { - return []string{"pay_invoice", "pay_keysend", "get_balance", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice", "multi_pay_keysend", "sign_message"} + return []string{"pay_invoice", "pay_keysend", "get_balance", "get_budget", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice", "multi_pay_keysend", "sign_message"} } func (ls *LDKService) GetSupportedNIP47NotificationTypes() []string { diff --git a/lnclient/lnd/lnd.go b/lnclient/lnd/lnd.go index fdeb0dc3..c0eb4d01 100644 --- a/lnclient/lnd/lnd.go +++ b/lnclient/lnd/lnd.go @@ -1084,7 +1084,7 @@ func (svc *LNDService) DisconnectPeer(ctx context.Context, peerId string) error func (svc *LNDService) GetSupportedNIP47Methods() []string { return []string{ - "pay_invoice", "pay_keysend", "get_balance", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice", "multi_pay_keysend", "sign_message", + "pay_invoice", "pay_keysend", "get_balance", "get_budget", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice", "multi_pay_keysend", "sign_message", } } diff --git a/lnclient/phoenixd/phoenixd.go b/lnclient/phoenixd/phoenixd.go index f76669bd..e8c069bb 100644 --- a/lnclient/phoenixd/phoenixd.go +++ b/lnclient/phoenixd/phoenixd.go @@ -525,7 +525,7 @@ func (svc *PhoenixService) UpdateChannel(ctx context.Context, updateChannelReque } func (svc *PhoenixService) GetSupportedNIP47Methods() []string { - return []string{"pay_invoice", "get_balance", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice"} + return []string{"pay_invoice", "get_balance", "get_budget", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice"} } func (svc *PhoenixService) GetSupportedNIP47NotificationTypes() []string { diff --git a/nip47/controllers/get_budget_controller.go b/nip47/controllers/get_budget_controller.go new file mode 100644 index 00000000..70fc746e --- /dev/null +++ b/nip47/controllers/get_budget_controller.go @@ -0,0 +1,52 @@ +package controllers + +import ( + "context" + + "github.com/getAlby/hub/db/queries" + "github.com/nbd-wtf/go-nostr" + + "github.com/getAlby/hub/db" + "github.com/getAlby/hub/logger" + "github.com/getAlby/hub/nip47/models" + "github.com/sirupsen/logrus" +) + +type getBudgetResponse struct { + UsedBudget uint64 `json:"used_budget"` + TotalBudget uint64 `json:"total_budget"` + RenewsAt *uint64 `json:"renews_at,omitempty"` + RenewalPeriod string `json:"renewal_period"` +} + +func (controller *nip47Controller) HandleGetBudgetEvent(ctx context.Context, nip47Request *models.Request, requestEventId uint, app *db.App, publishResponse publishFunc) { + + logger.Logger.WithFields(logrus.Fields{ + "request_event_id": requestEventId, + }).Debug("Getting budget") + + appPermission := db.AppPermission{} + controller.db.Where("app_id = ? AND scope = ?", app.ID, models.PAY_INVOICE_METHOD).First(&appPermission) + + maxAmount := appPermission.MaxAmountSat + if maxAmount == 0 { + publishResponse(&models.Response{ + ResultType: nip47Request.Method, + Result: struct{}{}, + }, nostr.Tags{}) + return + } + + usedBudget := queries.GetBudgetUsageSat(controller.db, &appPermission) + responsePayload := &getBudgetResponse{ + TotalBudget: uint64(maxAmount * 1000), + UsedBudget: usedBudget * 1000, + RenewalPeriod: appPermission.BudgetRenewal, + RenewsAt: queries.GetBudgetRenewsAt(appPermission.BudgetRenewal), + } + + publishResponse(&models.Response{ + ResultType: nip47Request.Method, + Result: responsePayload, + }, nostr.Tags{}) +} diff --git a/nip47/controllers/get_budget_controller_test.go b/nip47/controllers/get_budget_controller_test.go new file mode 100644 index 00000000..e007c3b2 --- /dev/null +++ b/nip47/controllers/get_budget_controller_test.go @@ -0,0 +1,249 @@ +package controllers + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/nbd-wtf/go-nostr" + "github.com/stretchr/testify/assert" + + "github.com/getAlby/hub/constants" + "github.com/getAlby/hub/db" + "github.com/getAlby/hub/nip47/models" + "github.com/getAlby/hub/nip47/permissions" + "github.com/getAlby/hub/tests" + "github.com/getAlby/hub/transactions" +) + +const nip47GetBudgetJson = ` +{ + "method": "get_budget" +} +` + +func TestHandleGetBudgetEvent_NoRenewal(t *testing.T) { + ctx := context.TODO() + defer tests.RemoveTestService() + svc, err := tests.CreateTestService() + assert.NoError(t, err) + + nip47Request := &models.Request{} + err = json.Unmarshal([]byte(nip47GetBudgetJson), nip47Request) + assert.NoError(t, err) + + app, _, err := tests.CreateApp(svc) + assert.NoError(t, err) + + appPermission := &db.AppPermission{ + AppId: app.ID, + App: *app, + Scope: constants.PAY_INVOICE_SCOPE, + MaxAmountSat: 400, + BudgetRenewal: constants.BUDGET_RENEWAL_NEVER, + } + err = svc.DB.Create(appPermission).Error + assert.NoError(t, err) + + dbRequestEvent := &db.RequestEvent{} + err = svc.DB.Create(&dbRequestEvent).Error + assert.NoError(t, err) + + var publishedResponse *models.Response + + publishResponse := func(response *models.Response, tags nostr.Tags) { + publishedResponse = response + } + + permissionsSvc := permissions.NewPermissionsService(svc.DB, svc.EventPublisher) + transactionsSvc := transactions.NewTransactionsService(svc.DB, svc.EventPublisher) + NewNip47Controller(svc.LNClient, svc.DB, svc.EventPublisher, permissionsSvc, transactionsSvc). + HandleGetBudgetEvent(ctx, nip47Request, dbRequestEvent.ID, app, publishResponse) + + assert.Equal(t, uint64(400000), publishedResponse.Result.(*getBudgetResponse).TotalBudget) + assert.Equal(t, uint64(0), publishedResponse.Result.(*getBudgetResponse).UsedBudget) + assert.Nil(t, publishedResponse.Result.(*getBudgetResponse).RenewsAt) + assert.Equal(t, constants.BUDGET_RENEWAL_NEVER, publishedResponse.Result.(*getBudgetResponse).RenewalPeriod) + assert.Nil(t, publishedResponse.Error) +} + +func TestHandleGetBudgetEvent_NoneUsed(t *testing.T) { + ctx := context.TODO() + defer tests.RemoveTestService() + svc, err := tests.CreateTestService() + assert.NoError(t, err) + + nip47Request := &models.Request{} + err = json.Unmarshal([]byte(nip47GetBudgetJson), nip47Request) + assert.NoError(t, err) + + app, _, err := tests.CreateApp(svc) + assert.NoError(t, err) + now := time.Now() + + appPermission := &db.AppPermission{ + AppId: app.ID, + App: *app, + Scope: constants.PAY_INVOICE_SCOPE, + MaxAmountSat: 400, + BudgetRenewal: constants.BUDGET_RENEWAL_MONTHLY, + } + err = svc.DB.Create(appPermission).Error + assert.NoError(t, err) + + dbRequestEvent := &db.RequestEvent{} + err = svc.DB.Create(&dbRequestEvent).Error + assert.NoError(t, err) + + var publishedResponse *models.Response + + publishResponse := func(response *models.Response, tags nostr.Tags) { + publishedResponse = response + } + + permissionsSvc := permissions.NewPermissionsService(svc.DB, svc.EventPublisher) + transactionsSvc := transactions.NewTransactionsService(svc.DB, svc.EventPublisher) + NewNip47Controller(svc.LNClient, svc.DB, svc.EventPublisher, permissionsSvc, transactionsSvc). + HandleGetBudgetEvent(ctx, nip47Request, dbRequestEvent.ID, app, publishResponse) + + assert.Equal(t, uint64(400000), publishedResponse.Result.(*getBudgetResponse).TotalBudget) + assert.Equal(t, uint64(0), publishedResponse.Result.(*getBudgetResponse).UsedBudget) + renewsAt := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()).AddDate(0, 1, 0).Unix() + assert.Equal(t, uint64(renewsAt), *publishedResponse.Result.(*getBudgetResponse).RenewsAt) + assert.Equal(t, constants.BUDGET_RENEWAL_MONTHLY, publishedResponse.Result.(*getBudgetResponse).RenewalPeriod) + assert.Nil(t, publishedResponse.Error) +} + +func TestHandleGetBudgetEvent_HalfUsed(t *testing.T) { + ctx := context.TODO() + defer tests.RemoveTestService() + svc, err := tests.CreateTestService() + assert.NoError(t, err) + + nip47Request := &models.Request{} + err = json.Unmarshal([]byte(nip47GetBudgetJson), nip47Request) + assert.NoError(t, err) + + app, _, err := tests.CreateApp(svc) + assert.NoError(t, err) + now := time.Now() + + appPermission := &db.AppPermission{ + AppId: app.ID, + App: *app, + Scope: constants.PAY_INVOICE_SCOPE, + MaxAmountSat: 400, + BudgetRenewal: constants.BUDGET_RENEWAL_MONTHLY, + } + err = svc.DB.Create(appPermission).Error + assert.NoError(t, err) + + svc.DB.Create(&db.Transaction{ + AppId: &app.ID, + State: constants.TRANSACTION_STATE_SETTLED, + Type: constants.TRANSACTION_TYPE_OUTGOING, + AmountMsat: 200000, + }) + + dbRequestEvent := &db.RequestEvent{} + err = svc.DB.Create(&dbRequestEvent).Error + assert.NoError(t, err) + + var publishedResponse *models.Response + + publishResponse := func(response *models.Response, tags nostr.Tags) { + publishedResponse = response + } + + permissionsSvc := permissions.NewPermissionsService(svc.DB, svc.EventPublisher) + transactionsSvc := transactions.NewTransactionsService(svc.DB, svc.EventPublisher) + NewNip47Controller(svc.LNClient, svc.DB, svc.EventPublisher, permissionsSvc, transactionsSvc). + HandleGetBudgetEvent(ctx, nip47Request, dbRequestEvent.ID, app, publishResponse) + + assert.Equal(t, uint64(400000), publishedResponse.Result.(*getBudgetResponse).TotalBudget) + assert.Equal(t, uint64(200000), publishedResponse.Result.(*getBudgetResponse).UsedBudget) + renewsAt := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()).AddDate(0, 1, 0).Unix() + assert.Equal(t, uint64(renewsAt), *publishedResponse.Result.(*getBudgetResponse).RenewsAt) + assert.Equal(t, constants.BUDGET_RENEWAL_MONTHLY, publishedResponse.Result.(*getBudgetResponse).RenewalPeriod) + assert.Nil(t, publishedResponse.Error) +} + +func TestHandleGetBudgetEvent_NoBudget(t *testing.T) { + ctx := context.TODO() + defer tests.RemoveTestService() + svc, err := tests.CreateTestService() + assert.NoError(t, err) + + nip47Request := &models.Request{} + err = json.Unmarshal([]byte(nip47GetBudgetJson), nip47Request) + assert.NoError(t, err) + + app, _, err := tests.CreateApp(svc) + assert.NoError(t, err) + + appPermission := &db.AppPermission{ + AppId: app.ID, + App: *app, + Scope: constants.PAY_INVOICE_SCOPE, + } + err = svc.DB.Create(appPermission).Error + assert.NoError(t, err) + + svc.DB.Create(&db.Transaction{ + AppId: &app.ID, + State: constants.TRANSACTION_STATE_SETTLED, + Type: constants.TRANSACTION_TYPE_OUTGOING, + AmountMsat: 200000, + }) + + dbRequestEvent := &db.RequestEvent{} + err = svc.DB.Create(&dbRequestEvent).Error + assert.NoError(t, err) + + var publishedResponse *models.Response + + publishResponse := func(response *models.Response, tags nostr.Tags) { + publishedResponse = response + } + + permissionsSvc := permissions.NewPermissionsService(svc.DB, svc.EventPublisher) + transactionsSvc := transactions.NewTransactionsService(svc.DB, svc.EventPublisher) + NewNip47Controller(svc.LNClient, svc.DB, svc.EventPublisher, permissionsSvc, transactionsSvc). + HandleGetBudgetEvent(ctx, nip47Request, dbRequestEvent.ID, app, publishResponse) + + assert.Equal(t, struct{}{}, publishedResponse.Result) + assert.Nil(t, publishedResponse.Error) +} + +func TestHandleGetBudgetEvent_NoPayInvoicePermission(t *testing.T) { + ctx := context.TODO() + defer tests.RemoveTestService() + svc, err := tests.CreateTestService() + assert.NoError(t, err) + + nip47Request := &models.Request{} + err = json.Unmarshal([]byte(nip47GetBudgetJson), nip47Request) + assert.NoError(t, err) + + app, _, err := tests.CreateApp(svc) + assert.NoError(t, err) + + dbRequestEvent := &db.RequestEvent{} + err = svc.DB.Create(&dbRequestEvent).Error + assert.NoError(t, err) + + var publishedResponse *models.Response + + publishResponse := func(response *models.Response, tags nostr.Tags) { + publishedResponse = response + } + + permissionsSvc := permissions.NewPermissionsService(svc.DB, svc.EventPublisher) + transactionsSvc := transactions.NewTransactionsService(svc.DB, svc.EventPublisher) + NewNip47Controller(svc.LNClient, svc.DB, svc.EventPublisher, permissionsSvc, transactionsSvc). + HandleGetBudgetEvent(ctx, nip47Request, dbRequestEvent.ID, app, publishResponse) + + assert.Equal(t, struct{}{}, publishedResponse.Result) + assert.Nil(t, publishedResponse.Error) +} diff --git a/nip47/controllers/get_info_controller.go b/nip47/controllers/get_info_controller.go index eec4c945..1145cdd3 100644 --- a/nip47/controllers/get_info_controller.go +++ b/nip47/controllers/get_info_controller.go @@ -34,6 +34,8 @@ func (controller *nip47Controller) HandleGetInfoEvent(ctx context.Context, nip47 } // basic permissions check + // this is inconsistent with other methods. Ideally we move fetching node info to a separate method, + // so that get_info does not require its own scope. This would require a change in the NIP-47 spec. hasPermission, _, _ := controller.permissionsService.HasPermission(app, constants.GET_INFO_SCOPE) if hasPermission { logger.Logger.WithFields(logrus.Fields{ diff --git a/nip47/controllers/get_info_controller_test.go b/nip47/controllers/get_info_controller_test.go index 1d7e69e5..d8498c99 100644 --- a/nip47/controllers/get_info_controller_test.go +++ b/nip47/controllers/get_info_controller_test.go @@ -66,7 +66,8 @@ func TestHandleGetInfoEvent_NoPermission(t *testing.T) { assert.Empty(t, nodeInfo.Network) assert.Empty(t, nodeInfo.BlockHeight) assert.Empty(t, nodeInfo.BlockHash) - assert.Equal(t, []string{"get_balance"}, nodeInfo.Methods) + // get_info method is always granted, but does not return pubkey + assert.Contains(t, nodeInfo.Methods, models.GET_INFO_METHOD) assert.Equal(t, []string{}, nodeInfo.Notifications) } @@ -114,7 +115,7 @@ func TestHandleGetInfoEvent_WithPermission(t *testing.T) { assert.Equal(t, tests.MockNodeInfo.Network, nodeInfo.Network) assert.Equal(t, tests.MockNodeInfo.BlockHeight, nodeInfo.BlockHeight) assert.Equal(t, tests.MockNodeInfo.BlockHash, nodeInfo.BlockHash) - assert.Equal(t, []string{"get_info"}, nodeInfo.Methods) + assert.Contains(t, nodeInfo.Methods, "get_info") assert.Equal(t, []string{}, nodeInfo.Notifications) } @@ -170,6 +171,6 @@ func TestHandleGetInfoEvent_WithNotifications(t *testing.T) { assert.Equal(t, tests.MockNodeInfo.Network, nodeInfo.Network) assert.Equal(t, tests.MockNodeInfo.BlockHeight, nodeInfo.BlockHeight) assert.Equal(t, tests.MockNodeInfo.BlockHash, nodeInfo.BlockHash) - assert.Equal(t, []string{"get_info"}, nodeInfo.Methods) + assert.Contains(t, nodeInfo.Methods, "get_info") assert.Equal(t, []string{"payment_received", "payment_sent"}, nodeInfo.Notifications) } diff --git a/nip47/event_handler.go b/nip47/event_handler.go index abd66897..aeb64143 100644 --- a/nip47/event_handler.go +++ b/nip47/event_handler.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "slices" "time" "github.com/getAlby/hub/constants" @@ -269,7 +270,7 @@ func (svc *nip47Service) HandleEvent(ctx context.Context, relay nostrmodels.Rela "params": nip47Request.Params, }).Debug("Handling NIP-47 request") - if nip47Request.Method != models.GET_INFO_METHOD { + if !slices.Contains(permissions.GetAlwaysGrantedMethods(), nip47Request.Method) { scope, err := permissions.RequestMethodToScope(nip47Request.Method) if err != nil { publishResponse(&models.Response{ @@ -330,6 +331,9 @@ func (svc *nip47Service) HandleEvent(ctx context.Context, relay nostrmodels.Rela case models.GET_BALANCE_METHOD: controller. HandleGetBalanceEvent(ctx, nip47Request, requestEvent.ID, &app, publishResponse) + case models.GET_BUDGET_METHOD: + controller. + HandleGetBudgetEvent(ctx, nip47Request, requestEvent.ID, &app, publishResponse) case models.MAKE_INVOICE_METHOD: controller. HandleMakeInvoiceEvent(ctx, nip47Request, requestEvent.ID, app.ID, publishResponse) diff --git a/nip47/event_handler_test.go b/nip47/event_handler_test.go index 114439d4..ed3cef55 100644 --- a/nip47/event_handler_test.go +++ b/nip47/event_handler_test.go @@ -3,11 +3,13 @@ package nip47 import ( "context" "encoding/json" + "slices" "testing" "github.com/getAlby/hub/constants" "github.com/getAlby/hub/db" "github.com/getAlby/hub/nip47/models" + "github.com/getAlby/hub/nip47/permissions" "github.com/getAlby/hub/tests" "github.com/nbd-wtf/go-nostr" "github.com/nbd-wtf/go-nostr/nip04" @@ -124,13 +126,23 @@ func TestHandleResponse_WithPermission(t *testing.T) { decrypted, err := nip04.Decrypt(relay.PublishedEvent.Content, ss) assert.NoError(t, err) - unmarshalledResponse := models.Response{} + type getInfoResult struct { + Methods []string `json:"methods"` + } + + type getInfoResponseWrapper struct { + models.Response + Result getInfoResult `json:"result"` + } + + unmarshalledResponse := getInfoResponseWrapper{} err = json.Unmarshal([]byte(decrypted), &unmarshalledResponse) assert.NoError(t, err) assert.Nil(t, unmarshalledResponse.Error) assert.Equal(t, models.GET_INFO_METHOD, unmarshalledResponse.ResultType) - assert.Equal(t, []interface{}{"get_balance"}, unmarshalledResponse.Result.(map[string]interface{})["methods"]) + expectedMethods := slices.Concat([]string{constants.GET_BALANCE_SCOPE}, permissions.GetAlwaysGrantedMethods()) + assert.Equal(t, expectedMethods, unmarshalledResponse.Result.Methods) } func TestHandleResponse_DuplicateRequest(t *testing.T) { diff --git a/nip47/models/models.go b/nip47/models/models.go index 77491568..648cd921 100644 --- a/nip47/models/models.go +++ b/nip47/models/models.go @@ -13,6 +13,7 @@ const ( // request methods PAY_INVOICE_METHOD = "pay_invoice" GET_BALANCE_METHOD = "get_balance" + GET_BUDGET_METHOD = "get_budget" GET_INFO_METHOD = "get_info" MAKE_INVOICE_METHOD = "make_invoice" LOOKUP_INVOICE_METHOD = "lookup_invoice" diff --git a/nip47/permissions/permissions.go b/nip47/permissions/permissions.go index 00d57eb8..b3f78ddc 100644 --- a/nip47/permissions/permissions.go +++ b/nip47/permissions/permissions.go @@ -36,7 +36,6 @@ func NewPermissionsService(db *gorm.DB, eventPublisher events.EventPublisher) *p } func (svc *permissionsService) HasPermission(app *db.App, scope string) (result bool, code string, message string) { - appPermission := db.AppPermission{} findPermissionResult := svc.db.Limit(1).Find(&appPermission, &db.AppPermission{ AppId: app.ID, @@ -71,6 +70,12 @@ func (svc *permissionsService) GetPermittedMethods(app *db.App, lnClient lnclien requestMethods := scopesToRequestMethods(scopes) + for _, method := range GetAlwaysGrantedMethods() { + if !slices.Contains(requestMethods, method) { + requestMethods = append(requestMethods, method) + } + } + // only return methods supported by the lnClient lnClientSupportedMethods := lnClient.GetSupportedNIP47Methods() requestMethods = utils.Filter(requestMethods, func(requestMethod string) bool { @@ -86,11 +91,8 @@ func (svc *permissionsService) PermitsNotifications(app *db.App) bool { AppId: app.ID, Scope: constants.NOTIFICATIONS_SCOPE, }).Error - if err != nil { - return false - } - return true + return err == nil } func scopesToRequestMethods(scopes []string) []string { @@ -131,7 +133,7 @@ func RequestMethodsToScopes(requestMethods []string) ([]string, error) { if err != nil { return nil, err } - if !slices.Contains(scopes, scope) { + if scope != "" && !slices.Contains(scopes, scope) { scopes = append(scopes, scope) } } @@ -144,6 +146,8 @@ func RequestMethodToScope(requestMethod string) (string, error) { return constants.PAY_INVOICE_SCOPE, nil case models.GET_BALANCE_METHOD: return constants.GET_BALANCE_SCOPE, nil + case models.GET_BUDGET_METHOD: + return "", nil case models.GET_INFO_METHOD: return constants.GET_INFO_SCOPE, nil case models.MAKE_INVOICE_METHOD: @@ -171,3 +175,7 @@ func AllScopes() []string { constants.NOTIFICATIONS_SCOPE, } } + +func GetAlwaysGrantedMethods() []string { + return []string{models.GET_INFO_METHOD, models.GET_BUDGET_METHOD} +} diff --git a/nip47/permissions/permissions_test.go b/nip47/permissions/permissions_test.go index d108cfd2..56e72131 100644 --- a/nip47/permissions/permissions_test.go +++ b/nip47/permissions/permissions_test.go @@ -68,7 +68,7 @@ func TestHasPermission_Expired(t *testing.T) { appPermission := &db.AppPermission{ AppId: app.ID, App: *app, - Scope: models.PAY_INVOICE_METHOD, + Scope: constants.PAY_INVOICE_SCOPE, MaxAmountSat: 10, BudgetRenewal: budgetRenewal, ExpiresAt: &expiresAt, @@ -96,7 +96,7 @@ func TestHasPermission_OK(t *testing.T) { appPermission := &db.AppPermission{ AppId: app.ID, App: *app, - Scope: models.PAY_INVOICE_METHOD, + Scope: constants.PAY_INVOICE_SCOPE, MaxAmountSat: 10, BudgetRenewal: budgetRenewal, ExpiresAt: &expiresAt, @@ -105,8 +105,77 @@ func TestHasPermission_OK(t *testing.T) { assert.NoError(t, err) permissionsSvc := NewPermissionsService(svc.DB, svc.EventPublisher) - result, code, message := permissionsSvc.HasPermission(app, models.PAY_INVOICE_METHOD) + result, code, message := permissionsSvc.HasPermission(app, constants.PAY_INVOICE_SCOPE) assert.True(t, result) assert.Empty(t, code) assert.Empty(t, message) } + +func TestRequestMethodToScope_GetBudget(t *testing.T) { + defer tests.RemoveTestService() + _, err := tests.CreateTestService() + assert.NoError(t, err) + + scope, err := RequestMethodToScope(models.GET_BUDGET_METHOD) + assert.Nil(t, err) + assert.Equal(t, "", scope) +} + +func TestRequestMethodsToScopes_GetBudget(t *testing.T) { + defer tests.RemoveTestService() + _, err := tests.CreateTestService() + assert.NoError(t, err) + + scopes, err := RequestMethodsToScopes([]string{models.GET_BUDGET_METHOD}) + assert.NoError(t, err) + assert.Equal(t, []string{}, scopes) +} + +func TestRequestMethodToScope_GetInfo(t *testing.T) { + scope, err := RequestMethodToScope(models.GET_INFO_METHOD) + assert.NoError(t, err) + assert.Equal(t, constants.GET_INFO_SCOPE, scope) +} + +func TestRequestMethodsToScopes_GetInfo(t *testing.T) { + scopes, err := RequestMethodsToScopes([]string{models.GET_INFO_METHOD}) + assert.NoError(t, err) + assert.Equal(t, []string{constants.GET_INFO_SCOPE}, scopes) +} + +func TestGetPermittedMethods_AlwaysGranted(t *testing.T) { + defer tests.RemoveTestService() + svc, err := tests.CreateTestService() + assert.NoError(t, err) + + app, _, err := tests.CreateApp(svc) + assert.NoError(t, err) + + permissionsSvc := NewPermissionsService(svc.DB, svc.EventPublisher) + result := permissionsSvc.GetPermittedMethods(app, svc.LNClient) + assert.Equal(t, GetAlwaysGrantedMethods(), result) +} + +func TestGetPermittedMethods_PayInvoiceScopeGivesAllPaymentMethods(t *testing.T) { + defer tests.RemoveTestService() + svc, err := tests.CreateTestService() + assert.NoError(t, err) + + app, _, err := tests.CreateApp(svc) + assert.NoError(t, err) + + appPermission := &db.AppPermission{ + AppId: app.ID, + App: *app, + Scope: constants.PAY_INVOICE_SCOPE, + } + err = svc.DB.Create(appPermission).Error + assert.NoError(t, err) + + permissionsSvc := NewPermissionsService(svc.DB, svc.EventPublisher) + result := permissionsSvc.GetPermittedMethods(app, svc.LNClient) + assert.Contains(t, result, models.PAY_INVOICE_METHOD) + assert.Contains(t, result, models.PAY_KEYSEND_METHOD) + assert.Contains(t, result, models.MULTI_PAY_INVOICE_METHOD) + assert.Contains(t, result, models.MULTI_PAY_KEYSEND_METHOD) +} diff --git a/tests/mock_ln_client.go b/tests/mock_ln_client.go index 2c314f41..7e8debfb 100644 --- a/tests/mock_ln_client.go +++ b/tests/mock_ln_client.go @@ -182,7 +182,7 @@ func (mln *MockLn) UpdateChannel(ctx context.Context, updateChannelRequest *lncl } func (mln *MockLn) GetSupportedNIP47Methods() []string { - return []string{"pay_invoice", "pay_keysend", "get_balance", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice", "multi_pay_keysend", "sign_message"} + return []string{"pay_invoice", "pay_keysend", "get_balance", "get_budget", "get_info", "make_invoice", "lookup_invoice", "list_transactions", "multi_pay_invoice", "multi_pay_keysend", "sign_message"} } func (mln *MockLn) GetSupportedNIP47NotificationTypes() []string { if mln.SupportedNotificationTypes != nil {