Skip to content

Commit

Permalink
fix: add isolation level in transaction repository
Browse files Browse the repository at this point in the history
Currently there were chances to have race conditions while writing
to transaction repository. I have added a test to verify it doesn't
happen. Database is using repeatable read as the isolation level
to avoid overlapping transactions.

Signed-off-by: Kush Sharma <[email protected]>
  • Loading branch information
kushsharma committed Nov 29, 2024
1 parent 2968371 commit c1ec2e0
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 112 deletions.
8 changes: 6 additions & 2 deletions internal/api/v1beta1/org.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package v1beta1
import (
"context"

"github.com/raystack/frontier/core/serviceuser"

"github.com/raystack/frontier/core/authenticate"

"go.uber.org/zap"
Expand Down Expand Up @@ -318,13 +320,15 @@ func (h Handler) ListOrganizationServiceUsers(ctx context.Context, request *fron
}
}

users, err := h.serviceUserService.ListByOrg(ctx, orgResp.ID)
usersList, err := h.serviceUserService.List(ctx, serviceuser.Filter{
OrgID: orgResp.ID,
})
if err != nil {
return nil, err
}

var usersPB []*frontierv1beta1.ServiceUser
for _, rel := range users {
for _, rel := range usersList {
u, err := transformServiceUserToPB(rel)
if err != nil {
return nil, err
Expand Down
234 changes: 126 additions & 108 deletions internal/store/postgres/billing_transactions_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"strings"
"time"

"github.com/raystack/frontier/billing/customer"
"github.com/raystack/frontier/internal/bootstrap/schema"

"github.com/jackc/pgconn"
Expand Down Expand Up @@ -81,122 +81,88 @@ func NewBillingTransactionRepository(dbc *db.Client) *BillingTransactionReposito
}
}

func (r BillingTransactionRepository) CreateEntry(ctx context.Context, debitEntry credit.Transaction,
creditEntry credit.Transaction) ([]credit.Transaction, error) {
var customerAcc customer.Customer
var err error
if debitEntry.CustomerID != schema.PlatformOrgID.String() {
// only fetch if it's a customer debit entry
customerAcc, err = r.customerRepo.GetByID(ctx, debitEntry.CustomerID)
if err != nil {
return nil, fmt.Errorf("failed to get customer account: %w", err)
var (
maxRetries = 5
// Error codes from https://www.postgresql.org/docs/current/errcodes-appendix.html
serializationFailureCode = "40001"
deadlockDetectedCode = "40P01"
)

func (r BillingTransactionRepository) withRetry(ctx context.Context, fn func() error) error {
var lastErr error
for i := 0; i < maxRetries && ctx.Err() == nil; i++ {
err := fn()
if err == nil {
return nil
}
}

if debitEntry.Metadata == nil {
debitEntry.Metadata = make(map[string]any)
}
debitMetadata, err := json.Marshal(debitEntry.Metadata)
if err != nil {
return nil, err
}
debitRecord := goqu.Record{
"account_id": debitEntry.CustomerID,
"description": debitEntry.Description,
"type": debitEntry.Type,
"source": debitEntry.Source,
"amount": debitEntry.Amount,
"user_id": debitEntry.UserID,
"metadata": debitMetadata,
"created_at": goqu.L("now()"),
"updated_at": goqu.L("now()"),
}
if debitEntry.ID != "" {
debitRecord["id"] = debitEntry.ID
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) {
// Retry on serialization failures or deadlocks
if pqErr.Code == serializationFailureCode || pqErr.Code == deadlockDetectedCode {
lastErr = err
// Exponential backoff with jitter
backoff := time.Duration(1<<uint(i)) * 100 * time.Millisecond
jitter := time.Duration(rand.Int63n(int64(backoff / 2)))
time.Sleep(backoff + jitter)
continue
}
}
return err // Return immediately for other errors
}
return fmt.Errorf("max retries exceeded: %w", lastErr)
}

if creditEntry.Metadata == nil {
creditEntry.Metadata = make(map[string]any)
}
creditMetadata, err := json.Marshal(creditEntry.Metadata)
if err != nil {
return nil, err
}
creditRecord := goqu.Record{
"account_id": creditEntry.CustomerID,
"description": creditEntry.Description,
"type": creditEntry.Type,
"source": creditEntry.Source,
"amount": creditEntry.Amount,
"user_id": creditEntry.UserID,
"metadata": creditMetadata,
"created_at": goqu.L("now()"),
"updated_at": goqu.L("now()"),
}
if creditEntry.ID != "" {
creditRecord["id"] = creditEntry.ID
func (r BillingTransactionRepository) CreateEntry(ctx context.Context, debitEntry credit.Transaction,
creditEntry credit.Transaction) ([]credit.Transaction, error) {
txOpts := sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
ReadOnly: false,
}

var creditReturnedEntry, debitReturnedEntry credit.Transaction
if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error {
// check if balance is enough if it's a customer entry
if customerAcc.ID != "" {
currentBalance, err := r.getBalanceInTx(ctx, tx, customerAcc.ID)
if err != nil {
return fmt.Errorf("failed to apply transaction: %w", err)
}
if err := isSufficientBalance(customerAcc.CreditMin, currentBalance, debitEntry.Amount); err != nil {
return err
}
}
err := r.withRetry(ctx, func() error {
return r.dbc.WithTxn(ctx, txOpts, func(tx *sqlx.Tx) error {
if debitEntry.CustomerID != schema.PlatformOrgID.String() {
currentBalance, err := r.getBalanceInTx(ctx, tx, debitEntry.CustomerID)
if err != nil {
return fmt.Errorf("failed to get balance: %w", err)
}

var debitModel Transaction
var creditModel Transaction
query, params, err := dialect.Insert(TABLE_BILLING_TRANSACTIONS).Rows(debitRecord).Returning(&Transaction{}).ToSQL()
if err != nil {
return fmt.Errorf("%w: %s", parseErr, err)
}
if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "Create", func(ctx context.Context) error {
return r.dbc.QueryRowxContext(ctx, query, params...).StructScan(&debitModel)
}); err != nil {
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && (pqErr.Code == "23505") { // handle unique key violations
if pqErr.ConstraintName == "billing_transactions_pkey" { // primary key violation
return credit.ErrAlreadyApplied
customerAcc, err := r.customerRepo.GetByID(ctx, debitEntry.CustomerID)
if err != nil {
return fmt.Errorf("failed to get customer account: %w", err)
}
// add other specific unique key violations here if needed
}
return fmt.Errorf("%w: %s", dbErr, err)
}

query, params, err = dialect.Insert(TABLE_BILLING_TRANSACTIONS).Rows(creditRecord).Returning(&Transaction{}).ToSQL()
if err != nil {
return fmt.Errorf("%w: %s", parseErr, err)
}
if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "Create", func(ctx context.Context) error {
return r.dbc.QueryRowxContext(ctx, query, params...).StructScan(&creditModel)
}); err != nil {
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && (pqErr.Code == "23505") { // handle unique key violations
if pqErr.ConstraintName == "billing_transactions_pkey" { // primary key violation
return credit.ErrAlreadyApplied
if err := isSufficientBalance(customerAcc.CreditMin, currentBalance, debitEntry.Amount); err != nil {
return err
}
// add other specific unique key violations here if needed
}
return fmt.Errorf("%w: %s", dbErr, err)
}

creditReturnedEntry, err = creditModel.transform()
if err != nil {
return fmt.Errorf("failed to transform credit entry: %w", err)
}
debitReturnedEntry, err = debitModel.transform()
if err != nil {
return fmt.Errorf("failed to transform debit entry: %w", err)
}
var debitModel Transaction
if err := r.createTransactionEntry(ctx, tx, debitEntry, &debitModel); err != nil {
return fmt.Errorf("failed to create debit entry: %w", err)
}

return nil
}); err != nil {
var creditModel Transaction
if err := r.createTransactionEntry(ctx, tx, creditEntry, &creditModel); err != nil {
return fmt.Errorf("failed to create credit entry: %w", err)
}

var err error
creditReturnedEntry, err = creditModel.transform()
if err != nil {
return fmt.Errorf("failed to transform credit entry: %w", err)
}
debitReturnedEntry, err = debitModel.transform()
if err != nil {
return fmt.Errorf("failed to transform debit entry: %w", err)
}

return nil
})
})
if err != nil {
if errors.Is(err, credit.ErrAlreadyApplied) {
return nil, credit.ErrAlreadyApplied
} else if errors.Is(err, credit.ErrInsufficientCredits) {
Expand All @@ -208,6 +174,50 @@ func (r BillingTransactionRepository) CreateEntry(ctx context.Context, debitEntr
return []credit.Transaction{debitReturnedEntry, creditReturnedEntry}, nil
}

func (r BillingTransactionRepository) createTransactionEntry(ctx context.Context, tx *sqlx.Tx, entry credit.Transaction, model *Transaction) error {
if entry.Metadata == nil {
entry.Metadata = make(map[string]any)
}
metadata, err := json.Marshal(entry.Metadata)
if err != nil {
return err
}

record := goqu.Record{
"account_id": entry.CustomerID,
"description": entry.Description,
"type": entry.Type,
"source": entry.Source,
"amount": entry.Amount,
"user_id": entry.UserID,
"metadata": metadata,
"created_at": goqu.L("now()"),
"updated_at": goqu.L("now()"),
}
if entry.ID != "" {
record["id"] = entry.ID
}

query, params, err := dialect.Insert(TABLE_BILLING_TRANSACTIONS).Rows(record).Returning(&Transaction{}).ToSQL()
if err != nil {
return fmt.Errorf("%w: %w", parseErr, err)
}

if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "Create", func(ctx context.Context) error {
return tx.QueryRowxContext(ctx, query, params...).StructScan(model)
}); err != nil {
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && (pqErr.Code == "23505") {
if pqErr.ConstraintName == "billing_transactions_pkey" {
return credit.ErrAlreadyApplied
}
}
return fmt.Errorf("%w: %w", dbErr, err)
}

return nil
}

// isSufficientBalance checks if the customer has enough balance to perform the transaction.
// If the customer has a credit min limit set, then a negative balance means loaner/overdraft limit and
// a positive limit mean at least that much balance should be there in the account.
Expand Down Expand Up @@ -328,6 +338,7 @@ func (r BillingTransactionRepository) getDebitBalance(ctx context.Context, tx *s
"account_id": accountID,
"type": credit.DebitType,
})

query, params, err := stmt.ToSQL()
if err != nil {
return nil, fmt.Errorf("%w: %s", parseErr, err)
Expand All @@ -347,6 +358,7 @@ func (r BillingTransactionRepository) getCreditBalance(ctx context.Context, tx *
"account_id": accountID,
"type": credit.CreditType,
})

query, params, err := stmt.ToSQL()
if err != nil {
return nil, fmt.Errorf("%w: %s", parseErr, err)
Expand Down Expand Up @@ -388,11 +400,17 @@ func (r BillingTransactionRepository) getBalanceInTx(ctx context.Context, tx *sq
// in transaction table till now.
func (r BillingTransactionRepository) GetBalance(ctx context.Context, accountID string) (int64, error) {
var amount int64
if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error {
var err error
amount, err = r.getBalanceInTx(ctx, tx, accountID)
return err
}); err != nil {
err := r.withRetry(ctx, func() error {
return r.dbc.WithTxn(ctx, sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
ReadOnly: true,
}, func(tx *sqlx.Tx) error {
var err error
amount, err = r.getBalanceInTx(ctx, tx, accountID)
return err
})
})
if err != nil {
return 0, fmt.Errorf("failed to get balance: %w", err)
}
return amount, nil
Expand Down
52 changes: 52 additions & 0 deletions test/e2e/regression/billing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,58 @@ func (s *BillingRegressionTestSuite) TestUsageAPI() {
})
s.Assert().NoError(err)
})
s.Run("10. check for concurrent transactions", func() {
// check initial balance
getBalanceResp, err := s.testBench.Client.GetBillingBalance(ctxOrgAdminAuth, &frontierv1beta1.GetBillingBalanceRequest{
OrgId: createOrgResp.GetOrganization().GetId(),
Id: createBillingResp.GetBillingAccount().GetId(),
})
s.Assert().NoError(err)
beforeBalance := getBalanceResp.GetBalance().GetAmount()

// Create multiple concurrent usage requests
numRequests := 20
errChan := make(chan error, numRequests)
for i := 0; i < numRequests; i++ {
go func() {
_, err := s.testBench.Client.CreateBillingUsage(ctxOrgAdminAuth, &frontierv1beta1.CreateBillingUsageRequest{
OrgId: createOrgResp.GetOrganization().GetId(),
BillingId: createBillingResp.GetBillingAccount().GetId(),
Usages: []*frontierv1beta1.Usage{
{
Id: uuid.New().String(),
Source: "billing.test",
Amount: 2,
UserId: testUserID,
},
},
})
errChan <- err
}()
}

// Wait for all requests to complete
var successCount int
for i := 0; i < numRequests; i++ {
err := <-errChan
if err == nil {
successCount++
} else {
s.Assert().ErrorContains(err, credit.ErrInsufficientCredits.Error())
}
}

// Verify final balance
getBalanceResp, err = s.testBench.Client.GetBillingBalance(ctxOrgAdminAuth, &frontierv1beta1.GetBillingBalanceRequest{
OrgId: createOrgResp.GetOrganization().GetId(),
Id: createBillingResp.GetBillingAccount().GetId(),
})
s.Assert().NoError(err)

// Verify the balance was deducted exactly by successful transactions amount
expectedBalance := beforeBalance - int64(successCount*2)
s.Assert().Equal(expectedBalance, getBalanceResp.GetBalance().GetAmount())
})
}

func (s *BillingRegressionTestSuite) TestCheckFeatureEntitlementAPI() {
Expand Down
4 changes: 2 additions & 2 deletions test/e2e/testbench/testbench.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ func Init(appConfig *config.Frontier) (*TestBench, error) {
URL: connMainPGExternal,
MaxIdleConns: 10,
MaxOpenConns: 10,
ConnMaxLifeTime: time.Millisecond * 100,
MaxQueryTimeout: time.Millisecond * 100,
ConnMaxLifeTime: time.Second * 60,
MaxQueryTimeout: time.Second * 30,
}
appConfig.SpiceDB = spicedb.Config{
Host: "localhost",
Expand Down

0 comments on commit c1ec2e0

Please sign in to comment.