Skip to content

Commit

Permalink
Merge pull request #932 from ellemouton/context
Browse files Browse the repository at this point in the history
multi: thread contexts through properly
  • Loading branch information
ellemouton authored Jan 14, 2025
2 parents 0fa5112 + 76250ca commit 54cf58b
Show file tree
Hide file tree
Showing 31 changed files with 390 additions and 357 deletions.
32 changes: 16 additions & 16 deletions accounts/checkers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,8 +499,8 @@ func TestSendPaymentCalls(t *testing.T) {

func testSendPayment(t *testing.T, uri string) {
var (
parentCtx = context.Background()
zeroFee = &lnrpc.FeeLimit{Limit: &lnrpc.FeeLimit_Fixed{
ctx = context.Background()
zeroFee = &lnrpc.FeeLimit{Limit: &lnrpc.FeeLimit_Fixed{
Fixed: 0,
}}
requestID uint64
Expand All @@ -520,7 +520,7 @@ func testSendPayment(t *testing.T, uri string) {
service, err := NewService(t.TempDir(), errFunc)
require.NoError(t, err)

err = service.Start(lndMock, routerMock, chainParams)
err = service.Start(ctx, lndMock, routerMock, chainParams)
require.NoError(t, err)

assertBalance := func(id AccountID, expectedBalance int64) {
Expand All @@ -533,7 +533,7 @@ func testSendPayment(t *testing.T, uri string) {

// This should error because there is no account in the context.
err = service.checkers.checkIncomingRequest(
parentCtx, uri, &lnrpc.SendRequest{},
ctx, uri, &lnrpc.SendRequest{},
)
require.ErrorContains(t, err, "no account found in context")

Expand All @@ -543,7 +543,7 @@ func testSendPayment(t *testing.T, uri string) {
)
require.NoError(t, err)

ctxWithAcct := AddAccountToContext(parentCtx, acct)
ctxWithAcct := AddAccountToContext(ctx, acct)

// This should error because there is no request ID in the context.
err = service.checkers.checkIncomingRequest(
Expand All @@ -552,7 +552,7 @@ func testSendPayment(t *testing.T, uri string) {
require.ErrorContains(t, err, "no request ID found in context")

reqID1 := nextRequestID()
ctx := AddRequestIDToContext(ctxWithAcct, reqID1)
ctx = AddRequestIDToContext(ctxWithAcct, reqID1)

// This should error because no payment hash is provided.
err = service.checkers.checkIncomingRequest(
Expand Down Expand Up @@ -698,7 +698,7 @@ func testSendPayment(t *testing.T, uri string) {
func TestSendPaymentV2(t *testing.T) {
var (
uri = "/routerrpc.Router/SendPaymentV2"
parentCtx = context.Background()
ctx = context.Background()
requestID uint64
)

Expand All @@ -716,7 +716,7 @@ func TestSendPaymentV2(t *testing.T) {
service, err := NewService(t.TempDir(), errFunc)
require.NoError(t, err)

err = service.Start(lndMock, routerMock, chainParams)
err = service.Start(ctx, lndMock, routerMock, chainParams)
require.NoError(t, err)

assertBalance := func(id AccountID, expectedBalance int64) {
Expand All @@ -729,7 +729,7 @@ func TestSendPaymentV2(t *testing.T) {

// This should error because there is no account in the context.
err = service.checkers.checkIncomingRequest(
parentCtx, uri, &routerrpc.SendPaymentRequest{},
ctx, uri, &routerrpc.SendPaymentRequest{},
)
require.ErrorContains(t, err, "no account found in context")

Expand All @@ -739,7 +739,7 @@ func TestSendPaymentV2(t *testing.T) {
)
require.NoError(t, err)

ctxWithAcct := AddAccountToContext(parentCtx, acct)
ctxWithAcct := AddAccountToContext(ctx, acct)

// This should error because there is no request ID in the context.
err = service.checkers.checkIncomingRequest(
Expand All @@ -748,7 +748,7 @@ func TestSendPaymentV2(t *testing.T) {
require.ErrorContains(t, err, "no request ID found in context")

reqID1 := nextRequestID()
ctx := AddRequestIDToContext(ctxWithAcct, reqID1)
ctx = AddRequestIDToContext(ctxWithAcct, reqID1)

// This should error because no payment hash is provided.
err = service.checkers.checkIncomingRequest(
Expand Down Expand Up @@ -885,7 +885,7 @@ func TestSendPaymentV2(t *testing.T) {
func TestSendToRouteV2(t *testing.T) {
var (
uri = "/routerrpc.Router/SendToRouteV2"
parentCtx = context.Background()
ctx = context.Background()
requestID uint64
)

Expand All @@ -903,7 +903,7 @@ func TestSendToRouteV2(t *testing.T) {
service, err := NewService(t.TempDir(), errFunc)
require.NoError(t, err)

err = service.Start(lndMock, routerMock, chainParams)
err = service.Start(ctx, lndMock, routerMock, chainParams)
require.NoError(t, err)

assertBalance := func(id AccountID, expectedBalance int64) {
Expand All @@ -916,7 +916,7 @@ func TestSendToRouteV2(t *testing.T) {

// This should error because there is no account in the context.
err = service.checkers.checkIncomingRequest(
parentCtx, uri, &routerrpc.SendToRouteRequest{},
ctx, uri, &routerrpc.SendToRouteRequest{},
)
require.ErrorContains(t, err, "no account found in context")

Expand All @@ -926,7 +926,7 @@ func TestSendToRouteV2(t *testing.T) {
)
require.NoError(t, err)

ctxWithAcct := AddAccountToContext(parentCtx, acct)
ctxWithAcct := AddAccountToContext(ctx, acct)

// This should error because there is no request ID in the context.
err = service.checkers.checkIncomingRequest(
Expand All @@ -935,7 +935,7 @@ func TestSendToRouteV2(t *testing.T) {
require.ErrorContains(t, err, "no request ID found in context")

reqID1 := nextRequestID()
ctx := AddRequestIDToContext(ctxWithAcct, reqID1)
ctx = AddRequestIDToContext(ctxWithAcct, reqID1)

// This should error because no payment hash is provided.
err = service.checkers.checkIncomingRequest(
Expand Down
19 changes: 10 additions & 9 deletions accounts/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/btcsuite/btcd/chaincfg"
"github.com/lightninglabs/lndclient"
"github.com/lightninglabs/taproot-assets/fn"
"github.com/lightningnetwork/lnd/channeldb"
invpkg "github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lnrpc"
Expand Down Expand Up @@ -55,7 +56,7 @@ type InterceptorService struct {
routerClient lndclient.RouterClient

mainCtx context.Context
contextCancel context.CancelFunc
contextCancel fn.Option[context.CancelFunc]

requestMtx sync.Mutex
checkers *AccountChecker
Expand Down Expand Up @@ -85,12 +86,8 @@ func NewService(dir string,
return nil, err
}

mainCtx, contextCancel := context.WithCancel(context.Background())

return &InterceptorService{
store: accountStore,
mainCtx: mainCtx,
contextCancel: contextCancel,
invoiceToAccount: make(map[lntypes.Hash]AccountID),
pendingPayments: make(map[lntypes.Hash]*trackedPayment),
requestValuesStore: newRequestValuesStore(),
Expand All @@ -101,9 +98,14 @@ func NewService(dir string,
}

// Start starts the account service and its interceptor capability.
func (s *InterceptorService) Start(lightningClient lndclient.LightningClient,
func (s *InterceptorService) Start(ctx context.Context,
lightningClient lndclient.LightningClient,
routerClient lndclient.RouterClient, params *chaincfg.Params) error {

mainCtx, contextCancel := context.WithCancel(ctx)
s.mainCtx = mainCtx
s.contextCancel = fn.Some(contextCancel)

s.routerClient = routerClient
s.checkers = NewAccountChecker(s, params)

Expand Down Expand Up @@ -180,7 +182,7 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient,
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer s.contextCancel()
defer contextCancel()

for {
select {
Expand Down Expand Up @@ -235,9 +237,8 @@ func (s *InterceptorService) Stop() error {
s.requestMtx.Lock()
defer s.requestMtx.Unlock()

s.contextCancel()
s.contextCancel.WhenSome(func(fn context.CancelFunc) { fn() })
close(s.quit)

s.wg.Wait()

return s.store.Close()
Expand Down
5 changes: 4 additions & 1 deletion accounts/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,10 @@ func TestAccountService(t *testing.T) {
}

// Any errors during startup expected?
err = service.Start(lndMock, routerMock, chainParams)
err = service.Start(
context.Background(), lndMock, routerMock,
chainParams,
)
if tc.startupErr != "" {
require.ErrorContains(tt, err, tc.startupErr)

Expand Down
24 changes: 14 additions & 10 deletions autopilotserver/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightninglabs/lightning-terminal/autopilotserverrpc"
"github.com/lightninglabs/taproot-assets/fn"
"github.com/lightningnetwork/lnd/tor"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -88,8 +89,9 @@ type Client struct {

featurePerms *featurePerms

quit chan struct{}
wg sync.WaitGroup
quit chan struct{}
wg sync.WaitGroup
cancel fn.Option[context.CancelFunc]
}

type session struct {
Expand Down Expand Up @@ -124,16 +126,19 @@ func NewClient(cfg *Config) (Autopilot, error) {
}

// Start kicks off all the goroutines required by the Client.
func (c *Client) Start(opts ...func(cfg *Config)) error {
func (c *Client) Start(ctx context.Context, opts ...func(cfg *Config)) error {
var startErr error
c.start.Do(func() {
log.Infof("Starting Autopilot Client")

ctx, cancel := context.WithCancel(ctx)
c.cancel = fn.Some(cancel)

for _, o := range opts {
o(c.cfg)
}

version, err := c.getMinVersion(context.Background())
version, err := c.getMinVersion(ctx)
if err != nil {
startErr = err
return
Expand All @@ -154,8 +159,8 @@ func (c *Client) Start(opts ...func(cfg *Config)) error {
}

c.wg.Add(2)
go c.activateSessionsForever()
go c.updateFeaturePermsForever()
go c.activateSessionsForever(ctx)
go c.updateFeaturePermsForever(ctx)
})

return startErr
Expand All @@ -164,6 +169,7 @@ func (c *Client) Start(opts ...func(cfg *Config)) error {
// Stop cleans up any resources or goroutines managed by the Client.
func (c *Client) Stop() {
c.stop.Do(func() {
c.cancel.WhenSome(func(fn context.CancelFunc) { fn() })
close(c.quit)
c.wg.Wait()
})
Expand Down Expand Up @@ -222,10 +228,9 @@ func (c *Client) SessionRevoked(ctx context.Context, pubKey *btcec.PublicKey) {

// activateSessionsForever periodically ensures that each of our active
// autopilot sessions are known by the autopilot to be active.
func (c *Client) activateSessionsForever() {
func (c *Client) activateSessionsForever(ctx context.Context) {
defer c.wg.Done()

ctx := context.Background()
ticker := time.NewTicker(c.cfg.PingCadence)
defer ticker.Stop()

Expand Down Expand Up @@ -273,10 +278,9 @@ func (c *Client) activateSessionsForever() {
// feature permissions list.
//
// NOTE: this MUST be called in a goroutine.
func (c *Client) updateFeaturePermsForever() {
func (c *Client) updateFeaturePermsForever(ctx context.Context) {
defer c.wg.Done()

ctx := context.Background()
ticker := time.NewTicker(time.Second)
defer ticker.Stop()

Expand Down
2 changes: 1 addition & 1 deletion autopilotserver/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestAutopilotClient(t *testing.T) {
PingCadence: time.Second,
})
require.NoError(t, err)
require.NoError(t, client.Start())
require.NoError(t, client.Start(ctx))
t.Cleanup(client.Stop)

privKey, err := btcec.NewPrivateKey()
Expand Down
2 changes: 1 addition & 1 deletion autopilotserver/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ type Autopilot interface {
SessionRevoked(ctx context.Context, key *btcec.PublicKey)

// Start kicks off the goroutines of the client.
Start(opts ...func(cfg *Config)) error
Start(ctx context.Context, opts ...func(cfg *Config)) error

// Stop cleans up any resources held by the client.
Stop()
Expand Down
Loading

0 comments on commit 54cf58b

Please sign in to comment.