From db45e3afe3eecce727f843f267944ce342ae82bb Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 7 Jan 2025 17:51:11 +0200 Subject: [PATCH 1/8] terminal: thread one context through --- terminal.go | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/terminal.go b/terminal.go index b27e3b2c5..148fe27a6 100644 --- a/terminal.go +++ b/terminal.go @@ -236,6 +236,8 @@ func New() *LightningTerminal { // Run starts everything and then blocks until either the application is shut // down or a critical error happens. func (g *LightningTerminal) Run() error { + ctx := context.TODO() + // Hook interceptor for os signals. shutdownInterceptor, err := signal.Intercept() if err != nil { @@ -345,7 +347,7 @@ func (g *LightningTerminal) Run() error { // We'll also create a REST proxy that'll convert any REST calls to gRPC // calls and forward them to the internal listener. if g.cfg.EnableREST { - if err := g.createRESTProxy(); err != nil { + if err := g.createRESTProxy(ctx); err != nil { return fmt.Errorf("error creating REST proxy: %v", err) } } @@ -353,7 +355,7 @@ func (g *LightningTerminal) Run() error { // Attempt to start Lit and all of its sub-servers. If an error is // returned, it means that either one of Lit's internal sub-servers // could not start or LND could not start or be connected to. - startErr := g.start() + startErr := g.start(ctx) if startErr != nil { g.statusMgr.SetErrored( subservers.LIT, "could not start Lit: %v", startErr, @@ -380,7 +382,7 @@ func (g *LightningTerminal) Run() error { // If any of the sub-servers managed by the subServerMgr error while starting // up, these are considered non-fatal and will not result in an error being // returned. -func (g *LightningTerminal) start() error { +func (g *LightningTerminal) start(ctx context.Context) error { var err error accountServiceErrCallback := func(err error) { @@ -658,7 +660,7 @@ func (g *LightningTerminal) start() error { // Now that we have started the main UI web server, show some useful // information to the user so they can access the web UI easily. - if err := g.showStartupInfo(); err != nil { + if err := g.showStartupInfo(ctx); err != nil { return fmt.Errorf("error displaying startup info: %v", err) } @@ -713,7 +715,7 @@ func (g *LightningTerminal) start() error { } // Set up all the LND clients required by LiT. - err = g.setUpLNDClients(lndQuit) + err = g.setUpLNDClients(ctx, lndQuit) if err != nil { g.statusMgr.SetErrored( subservers.LND, "could not set up LND clients: %v", err, @@ -773,7 +775,9 @@ func (g *LightningTerminal) basicLNDClient() (lnrpc.LightningClient, error) { } // setUpLNDClients sets up the various LND clients required by LiT. -func (g *LightningTerminal) setUpLNDClients(lndQuit chan struct{}) error { +func (g *LightningTerminal) setUpLNDClients(ctx context.Context, + lndQuit chan struct{}) error { + var ( err error insecure bool @@ -873,7 +877,7 @@ func (g *LightningTerminal) setUpLNDClients(lndQuit chan struct{}) error { // subservers. This will just block until lnd signals readiness. But we // still want to react to shutdown requests, so we need to listen for // those. - ctxc, cancel := context.WithCancel(context.Background()) + ctxc, cancel := context.WithCancel(ctx) defer cancel() // Make sure the context is canceled if the user requests shutdown. @@ -940,7 +944,6 @@ func (g *LightningTerminal) setUpLNDClients(lndQuit chan struct{}) error { // Create a super macaroon that can be used to control lnd, // faraday, loop, and pool, all at the same time. log.Infof("Baking internal super macaroon") - ctx := context.Background() superMacaroon, err := BakeSuperMacaroon( ctx, g.basicClient, session.NewSuperMacaroonRootKeyID( [4]byte{}, @@ -1657,7 +1660,7 @@ func (g *LightningTerminal) startMainWebServer() error { // createRESTProxy creates a grpc-gateway based REST proxy that takes any call // identified as a REST call, converts it to a gRPC request and forwards it to // our local main server for further triage/forwarding. -func (g *LightningTerminal) createRESTProxy() error { +func (g *LightningTerminal) createRESTProxy(ctx context.Context) error { // The default JSON marshaler of the REST proxy only sets OrigName to // true, which instructs it to use the same field names as specified in // the proto file and not switch to camel case. What we also want is @@ -1700,7 +1703,7 @@ func (g *LightningTerminal) createRESTProxy() error { // wildcard to prevent certificate issues when accessing the proxy // externally. restMux := restProxy.NewServeMux(customMarshalerOption) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) g.restCancel = cancel // Enable WebSocket and CORS support as well. A request will pass @@ -1926,7 +1929,7 @@ func allowCORS(handler http.Handler, origins []string) http.Handler { // showStartupInfo shows useful information to the user to easily access the // web UI that was just started. -func (g *LightningTerminal) showStartupInfo() error { +func (g *LightningTerminal) showStartupInfo(ctx context.Context) error { info := struct { mode string status string @@ -1958,7 +1961,6 @@ func (g *LightningTerminal) showStartupInfo() error { return fmt.Errorf("error querying remote node: %v", err) } - ctx := context.Background() res, err := basicClient.GetInfo(ctx, &lnrpc.GetInfoRequest{}) if err != nil { if !lndclient.IsUnlockError(err) { From 1fc0e6f16d1485b1b88324b912cf3db5a13abff5 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 7 Jan 2025 17:57:31 +0200 Subject: [PATCH 2/8] terminal: use ctx instead of interceptor --- cmd/litd/main.go | 3 ++- terminal.go | 58 +++++++++++++++++++++++------------------------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/cmd/litd/main.go b/cmd/litd/main.go index 7e4f655d4..238618441 100644 --- a/cmd/litd/main.go +++ b/cmd/litd/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "os" @@ -11,7 +12,7 @@ import ( // main starts the lightning-terminal application. func main() { - err := terminal.New().Run() + err := terminal.New().Run(context.Background()) var flagErr *flags.Error isFlagErr := errors.As(err, &flagErr) if err != nil && (!isFlagErr || flagErr.Type != flags.ErrHelp) { diff --git a/terminal.go b/terminal.go index 148fe27a6..39a5555aa 100644 --- a/terminal.go +++ b/terminal.go @@ -235,15 +235,31 @@ func New() *LightningTerminal { // Run starts everything and then blocks until either the application is shut // down or a critical error happens. -func (g *LightningTerminal) Run() error { - ctx := context.TODO() - +func (g *LightningTerminal) Run(ctx context.Context) error { // Hook interceptor for os signals. shutdownInterceptor, err := signal.Intercept() if err != nil { return fmt.Errorf("could not intercept signals: %v", err) } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Make sure the context is canceled if the user requests shutdown and + // that the shutdown signal is requested if the context is canceled. + go func() { + select { + // Client requests shutdown, cancel the wait. + case <-shutdownInterceptor.ShutdownChannel(): + cancel() + + // The check was completed and the above defer canceled the + // context. We can just exit the goroutine, nothing more to do. + case <-ctx.Done(): + shutdownInterceptor.RequestShutdown() + } + }() + cfg, err := loadAndValidateConfig(shutdownInterceptor) if err != nil { return fmt.Errorf("could not load config: %w", err) @@ -601,8 +617,8 @@ func (g *LightningTerminal) start(ctx context.Context) error { return fmt.Errorf("LND has stopped") - case <-interceptor.ShutdownChannel(): - return fmt.Errorf("received the shutdown signal") + case <-ctx.Done(): + return ctx.Err() } // Connect to LND. @@ -683,8 +699,8 @@ func (g *LightningTerminal) start(ctx context.Context) error { return fmt.Errorf("LND has stopped") - case <-interceptor.ShutdownChannel(): - return fmt.Errorf("received the shutdown signal") + case <-ctx.Done(): + return ctx.Err() } } @@ -758,7 +774,7 @@ func (g *LightningTerminal) start(ctx context.Context) error { return fmt.Errorf("LND is not running") - case <-interceptor.ShutdownChannel(): + case <-ctx.Done(): log.Infof("Shutdown signal received") } @@ -812,8 +828,8 @@ func (g *LightningTerminal) setUpLNDClients(ctx context.Context, case <-lndQuit: return fmt.Errorf("LND has stopped") - case <-interceptor.ShutdownChannel(): - return fmt.Errorf("received the shutdown signal") + case <-ctx.Done(): + return ctx.Err() case <-time.After(g.cfg.LndConnectInterval): return nil @@ -874,25 +890,7 @@ func (g *LightningTerminal) setUpLNDClients(ctx context.Context, // wallet being fully synced to its chain backend. The chain notifier // will always be ready first so if we instruct the lndclient to wait // for the wallet sync, we should be fully ready to start all our - // subservers. This will just block until lnd signals readiness. But we - // still want to react to shutdown requests, so we need to listen for - // those. - ctxc, cancel := context.WithCancel(ctx) - defer cancel() - - // Make sure the context is canceled if the user requests shutdown. - go func() { - select { - // Client requests shutdown, cancel the wait. - case <-interceptor.ShutdownChannel(): - cancel() - - // The check was completed and the above defer canceled the - // context. We can just exit the goroutine, nothing more to do. - case <-ctxc.Done(): - } - }() - + // subservers. This will just block until lnd signals readiness. log.Infof("Connecting full lnd client") for { g.lndClient, err = lndclient.NewLndServices( @@ -907,7 +905,7 @@ func (g *LightningTerminal) setUpLNDClients(ctx context.Context, ), BlockUntilChainSynced: true, BlockUntilUnlocked: true, - CallerCtx: ctxc, + CallerCtx: ctx, CheckVersion: minimalCompatibleVersion, RPCTimeout: g.cfg.LndRPCTimeout, ChainSyncPollInterval: g.cfg.LndConnectInterval, From 0503cfd433ff892ee3852e3eca97fc8347ae400a Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 13 Jan 2025 06:57:08 +0200 Subject: [PATCH 3/8] session_rpcsever: thread context through --- session_rpcserver.go | 22 +++++++++++----------- terminal.go | 6 +++--- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/session_rpcserver.go b/session_rpcserver.go index 2515c34a1..3ce968157 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -98,7 +98,7 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer, // start all the components necessary for the sessionRpcServer to start serving // requests. This includes resuming all non-revoked sessions. -func (s *sessionRpcServer) start() error { +func (s *sessionRpcServer) start(ctx context.Context) error { // Start up all previously created sessions. sessions, err := s.cfg.db.ListSessions(nil) if err != nil { @@ -135,7 +135,6 @@ func (s *sessionRpcServer) start() error { continue } - ctx := context.Background() ctxc, cancel := context.WithTimeout( ctx, defaultConnectTimeout, ) @@ -147,7 +146,7 @@ func (s *sessionRpcServer) start() error { cancel() if err != nil { log.Errorf("error activating autopilot "+ - "session (%x) with the client", key, + "session (%x) with the client: %v", key, err) if perm { @@ -164,7 +163,7 @@ func (s *sessionRpcServer) start() error { } } - if err := s.resumeSession(sess); err != nil { + if err := s.resumeSession(ctx, sess); err != nil { log.Errorf("error resuming session (%x): %v", key, err) } } @@ -190,7 +189,7 @@ func (s *sessionRpcServer) stop() error { } // AddSession adds and starts a new Terminal Connect session. -func (s *sessionRpcServer) AddSession(_ context.Context, +func (s *sessionRpcServer) AddSession(ctx context.Context, req *litrpc.AddSessionRequest) (*litrpc.AddSessionResponse, error) { expiry := time.Unix(int64(req.ExpiryTimestampSeconds), 0) @@ -335,7 +334,7 @@ func (s *sessionRpcServer) AddSession(_ context.Context, return nil, fmt.Errorf("error storing session: %v", err) } - if err := s.resumeSession(sess); err != nil { + if err := s.resumeSession(ctx, sess); err != nil { return nil, fmt.Errorf("error starting session: %v", err) } @@ -351,7 +350,9 @@ func (s *sessionRpcServer) AddSession(_ context.Context, // resumeSession tries to start an existing session if it is not expired, not // revoked and a LiT session. -func (s *sessionRpcServer) resumeSession(sess *session.Session) error { +func (s *sessionRpcServer) resumeSession(ctx context.Context, + sess *session.Session) error { + pubKey := sess.LocalPublicKey pubKeyBytes := pubKey.SerializeCompressed() @@ -423,7 +424,7 @@ func (s *sessionRpcServer) resumeSession(sess *session.Session) error { }) mac, err := s.cfg.superMacBaker( - context.Background(), sess.MacaroonRootKey, + ctx, sess.MacaroonRootKey, &session.MacaroonRecipe{ Permissions: permissions, Caveats: caveats, @@ -431,7 +432,7 @@ func (s *sessionRpcServer) resumeSession(sess *session.Session) error { ) if err != nil { log.Debugf("Not resuming session %x. Could not bake "+ - "the necessary macaroon: %w", pubKeyBytes, err) + "the necessary macaroon: %v", pubKeyBytes, err) return nil } @@ -516,7 +517,6 @@ func (s *sessionRpcServer) resumeSession(sess *session.Session) error { } if s.cfg.autopilot != nil { - ctx := context.Background() ctxc, cancel := context.WithTimeout( ctx, defaultConnectTimeout, ) @@ -1246,7 +1246,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, return nil, fmt.Errorf("error storing session: %v", err) } - if err := s.resumeSession(sess); err != nil { + if err := s.resumeSession(ctx, sess); err != nil { return nil, fmt.Errorf("error starting session: %v", err) } diff --git a/terminal.go b/terminal.go index 39a5555aa..2a9cbb316 100644 --- a/terminal.go +++ b/terminal.go @@ -751,7 +751,7 @@ func (g *LightningTerminal) start(ctx context.Context) error { g.basicClient, g.lndClient, createDefaultMacaroons, ) - err = g.startInternalSubServers(!g.cfg.statelessInitMode) + err = g.startInternalSubServers(ctx, !g.cfg.statelessInitMode) if err != nil { return fmt.Errorf("could not start litd sub-servers: %v", err) } @@ -959,7 +959,7 @@ func (g *LightningTerminal) setUpLNDClients(ctx context.Context, } // startInternalSubServers starts all Litd specific sub-servers. -func (g *LightningTerminal) startInternalSubServers( +func (g *LightningTerminal) startInternalSubServers(ctx context.Context, createDefaultMacaroons bool) error { log.Infof("Starting LiT macaroon service") @@ -1012,7 +1012,7 @@ func (g *LightningTerminal) startInternalSubServers( } log.Infof("Starting LiT session server") - if err = g.sessionRpcServer.start(); err != nil { + if err = g.sessionRpcServer.start(ctx); err != nil { return err } g.sessionRpcServerStarted = true From 21983bab7564833fb78a98ad0d54f2baedbbce09 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 13 Jan 2025 07:02:42 +0200 Subject: [PATCH 4/8] rules: thread context through --- firewall/rule_enforcer.go | 18 +++++++++--------- rules/chan_policy_bounds.go | 4 ++-- rules/channel_constraints.go | 4 ++-- rules/channel_restrictions.go | 13 +++++++------ rules/channel_restrictions_test.go | 2 +- rules/history_limit.go | 4 ++-- rules/interfaces.go | 3 ++- rules/manager_set.go | 5 +++-- rules/onchain_budget.go | 4 ++-- rules/peer_restrictions.go | 26 ++++++++++++++------------ rules/peer_restrictions_test.go | 2 +- rules/rate_limit.go | 4 ++-- 12 files changed, 47 insertions(+), 42 deletions(-) diff --git a/firewall/rule_enforcer.go b/firewall/rule_enforcer.go index e68f3781a..964baf32a 100644 --- a/firewall/rule_enforcer.go +++ b/firewall/rule_enforcer.go @@ -238,7 +238,7 @@ func (r *RuleEnforcer) handleRequest(ctx context.Context, return nil, fmt.Errorf("could not extract ID from macaroon") } - rules, err := r.collectEnforcers(ri, sessionID) + rules, err := r.collectEnforcers(ctx, ri, sessionID) if err != nil { return nil, fmt.Errorf("error parsing rules: %v", err) } @@ -294,7 +294,7 @@ func (r *RuleEnforcer) handleResponse(ctx context.Context, return nil, fmt.Errorf("could not extract ID from macaroon") } - enforcers, err := r.collectEnforcers(ri, sessionID) + enforcers, err := r.collectEnforcers(ctx, ri, sessionID) if err != nil { return nil, fmt.Errorf("error parsing rules: %v", err) } @@ -328,7 +328,7 @@ func (r *RuleEnforcer) handleErrorResponse(ctx context.Context, return nil, fmt.Errorf("could not extract ID from macaroon") } - enforcers, err := r.collectEnforcers(ri, sessionID) + enforcers, err := r.collectEnforcers(ctx, ri, sessionID) if err != nil { return nil, fmt.Errorf("error parsing rules: %v", err) } @@ -353,7 +353,7 @@ func (r *RuleEnforcer) handleErrorResponse(ctx context.Context, // collectRule initialises and returns all the Rules that need to be enforced // for the given request. -func (r *RuleEnforcer) collectEnforcers(ri *RequestInfo, +func (r *RuleEnforcer) collectEnforcers(ctx context.Context, ri *RequestInfo, sessionID session.ID) ([]rules.Enforcer, error) { ruleEnforcers := make( @@ -363,8 +363,8 @@ func (r *RuleEnforcer) collectEnforcers(ri *RequestInfo, for rule, value := range ri.Rules.FeatureRules[ri.MetaInfo.Feature] { r, err := r.initRule( - ri.RequestID, rule, []byte(value), ri.MetaInfo.Feature, - sessionID, false, ri.WithPrivacy, + ctx, ri.RequestID, rule, []byte(value), + ri.MetaInfo.Feature, sessionID, false, ri.WithPrivacy, ) if err != nil { return nil, err @@ -377,8 +377,8 @@ func (r *RuleEnforcer) collectEnforcers(ri *RequestInfo, } // initRule initialises a rule.Rule with any required config values. -func (r *RuleEnforcer) initRule(reqID uint64, name string, value []byte, - featureName string, sessionID session.ID, +func (r *RuleEnforcer) initRule(ctx context.Context, reqID uint64, name string, + value []byte, featureName string, sessionID session.ID, sessionRule, privacy bool) (rules.Enforcer, error) { ruleValues, err := r.ruleMgrs.InitRuleValues(name, value) @@ -425,5 +425,5 @@ func (r *RuleEnforcer) initRule(reqID uint64, name string, value []byte, LndConnID: r.lndConnID, } - return r.ruleMgrs.InitEnforcer(cfg, name, ruleValues) + return r.ruleMgrs.InitEnforcer(ctx, cfg, name, ruleValues) } diff --git a/rules/chan_policy_bounds.go b/rules/chan_policy_bounds.go index 7ad3159c1..9ba90ded6 100644 --- a/rules/chan_policy_bounds.go +++ b/rules/chan_policy_bounds.go @@ -39,8 +39,8 @@ func (b *ChanPolicyBoundsMgr) Stop() error { // values and config. // // NOTE: This is part of the Manager interface. -func (b *ChanPolicyBoundsMgr) NewEnforcer(_ Config, values Values) (Enforcer, - error) { +func (b *ChanPolicyBoundsMgr) NewEnforcer(_ context.Context, _ Config, + values Values) (Enforcer, error) { bounds, ok := values.(*ChanPolicyBounds) if !ok { diff --git a/rules/channel_constraints.go b/rules/channel_constraints.go index 8e287b899..e50e30df3 100644 --- a/rules/channel_constraints.go +++ b/rules/channel_constraints.go @@ -38,8 +38,8 @@ func (m *ChanConstraintMgr) Stop() error { // values and config. // // NOTE: This is part of the Manager interface. -func (m *ChanConstraintMgr) NewEnforcer(_ Config, values Values) (Enforcer, - error) { +func (m *ChanConstraintMgr) NewEnforcer(_ context.Context, _ Config, + values Values) (Enforcer, error) { bounds, ok := values.(*ChannelConstraint) if !ok { diff --git a/rules/channel_restrictions.go b/rules/channel_restrictions.go index 8594dde5e..745ed85be 100644 --- a/rules/channel_restrictions.go +++ b/rules/channel_restrictions.go @@ -60,8 +60,8 @@ func (c *ChannelRestrictMgr) Stop() error { // values and config. // // NOTE: This is part of the Manager interface. -func (c *ChannelRestrictMgr) NewEnforcer(cfg Config, values Values) (Enforcer, - error) { +func (c *ChannelRestrictMgr) NewEnforcer(ctx context.Context, cfg Config, + values Values) (Enforcer, error) { channels, ok := values.(*ChannelRestrict) if !ok { @@ -72,7 +72,8 @@ func (c *ChannelRestrictMgr) NewEnforcer(cfg Config, values Values) (Enforcer, chanMap := make(map[uint64]bool, len(channels.DenyList)) for _, chanID := range channels.DenyList { chanMap[chanID] = true - if err := c.maybeUpdateChannelMaps(cfg, chanID); err != nil { + err := c.maybeUpdateChannelMaps(ctx, cfg, chanID) + if err != nil { return nil, err } } @@ -118,8 +119,8 @@ func (c *ChannelRestrictMgr) EmptyValue() Values { // maybeUpdateChannelMaps updates the ChannelRestrictMgrs set of known channels // iff the channel given by the caller is not found in the current map set. -func (c *ChannelRestrictMgr) maybeUpdateChannelMaps(cfg Config, - chanID uint64) error { +func (c *ChannelRestrictMgr) maybeUpdateChannelMaps(ctx context.Context, + cfg Config, chanID uint64) error { c.mu.Lock() defer c.mu.Unlock() @@ -133,7 +134,7 @@ func (c *ChannelRestrictMgr) maybeUpdateChannelMaps(cfg Config, // Fetch a list of our open channels from LND. lnd := cfg.GetLndClient() - chans, err := lnd.ListChannels(context.Background(), false, false) + chans, err := lnd.ListChannels(ctx, false, false) if err != nil { return err } diff --git a/rules/channel_restrictions_test.go b/rules/channel_restrictions_test.go index d6ef6e8c8..a12c80916 100644 --- a/rules/channel_restrictions_test.go +++ b/rules/channel_restrictions_test.go @@ -53,7 +53,7 @@ func TestChannelRestrictCheckRequest(t *testing.T) { }, }, } - enf, err := mgr.NewEnforcer(cfg, &ChannelRestrict{ + enf, err := mgr.NewEnforcer(ctx, cfg, &ChannelRestrict{ DenyList: []uint64{ chanID1, chanID2, }, diff --git a/rules/history_limit.go b/rules/history_limit.go index 8ca0270fa..dccebef44 100644 --- a/rules/history_limit.go +++ b/rules/history_limit.go @@ -38,8 +38,8 @@ func (h *HistoryLimitMgr) Stop() error { // values and config. // // NOTE: This is part of the Manager interface. -func (h *HistoryLimitMgr) NewEnforcer(_ Config, values Values) (Enforcer, - error) { +func (h *HistoryLimitMgr) NewEnforcer(_ context.Context, _ Config, + values Values) (Enforcer, error) { limit, ok := values.(*HistoryLimit) if !ok { diff --git a/rules/interfaces.go b/rules/interfaces.go index 66fc27bfd..a1683c4c5 100644 --- a/rules/interfaces.go +++ b/rules/interfaces.go @@ -16,7 +16,8 @@ import ( type Manager interface { // NewEnforcer constructs a new rule enforcer using the passed values // and config. - NewEnforcer(cfg Config, values Values) (Enforcer, error) + NewEnforcer(ctx context.Context, cfg Config, values Values) (Enforcer, + error) // NewValueFromProto converts the given proto value into a Value object. NewValueFromProto(p *litrpc.RuleValue) (Values, error) diff --git a/rules/manager_set.go b/rules/manager_set.go index edc9497dc..fce0c95e7 100644 --- a/rules/manager_set.go +++ b/rules/manager_set.go @@ -1,6 +1,7 @@ package rules import ( + "context" "encoding/json" "fmt" @@ -32,7 +33,7 @@ func NewRuleManagerSet() ManagerSet { // InitEnforcer gets the appropriate rule Manager for the given name and uses it // to create an appropriate rule Enforcer. -func (m ManagerSet) InitEnforcer(cfg Config, name string, +func (m ManagerSet) InitEnforcer(ctx context.Context, cfg Config, name string, values Values) (Enforcer, error) { mgr, ok := m[name] @@ -41,7 +42,7 @@ func (m ManagerSet) InitEnforcer(cfg Config, name string, name) } - return mgr.NewEnforcer(cfg, values) + return mgr.NewEnforcer(ctx, cfg, values) } // GetAllRules returns a map of names of all the rules supported by rule diff --git a/rules/onchain_budget.go b/rules/onchain_budget.go index 7cfab158b..1024dd08e 100644 --- a/rules/onchain_budget.go +++ b/rules/onchain_budget.go @@ -63,8 +63,8 @@ func (o *OnChainBudgetMgr) Stop() error { // passed values and config. // // NOTE: This is part of the Manager interface. -func (o *OnChainBudgetMgr) NewEnforcer(cfg Config, values Values) (Enforcer, - error) { +func (o *OnChainBudgetMgr) NewEnforcer(_ context.Context, cfg Config, + values Values) (Enforcer, error) { budget, ok := values.(*OnChainBudget) if !ok { diff --git a/rules/peer_restrictions.go b/rules/peer_restrictions.go index 4f449f0c3..fbaefe94c 100644 --- a/rules/peer_restrictions.go +++ b/rules/peer_restrictions.go @@ -53,8 +53,8 @@ func (c *PeerRestrictMgr) Stop() error { // values and config. // // NOTE: This is part of the Manager interface. -func (c *PeerRestrictMgr) NewEnforcer(cfg Config, values Values) (Enforcer, - error) { +func (c *PeerRestrictMgr) NewEnforcer(ctx context.Context, cfg Config, + values Values) (Enforcer, error) { peers, ok := values.(*PeerRestrict) if !ok { @@ -65,7 +65,7 @@ func (c *PeerRestrictMgr) NewEnforcer(cfg Config, values Values) (Enforcer, peerMap := make(map[string]bool, len(peers.DenyList)) for _, peerID := range peers.DenyList { peerMap[peerID] = true - if err := c.maybeUpdateMaps(cfg, peerID); err != nil { + if err := c.maybeUpdateMaps(ctx, cfg, peerID); err != nil { return nil, err } } @@ -112,8 +112,8 @@ func (c *PeerRestrictMgr) EmptyValue() Values { // maybeUpdateMaps updates the managers peer-to-channel and channel-to-peer maps // if the given peer ID is unknown to the manager. -func (c *PeerRestrictMgr) maybeUpdateMaps(cfg peerRestrictCfg, - id string) error { +func (c *PeerRestrictMgr) maybeUpdateMaps(ctx context.Context, + cfg peerRestrictCfg, id string) error { c.mu.Lock() defer c.mu.Unlock() @@ -122,15 +122,17 @@ func (c *PeerRestrictMgr) maybeUpdateMaps(cfg peerRestrictCfg, return nil } - return c.updateMapsUnsafe(cfg) + return c.updateMapsUnsafe(ctx, cfg) } // updateMapsUnsafe updates the manager's peer-to-channel and channel-to-peer // maps. It is not thread safe and so must only be called if the manager's // mutex is being held. -func (c *PeerRestrictMgr) updateMapsUnsafe(cfg peerRestrictCfg) error { +func (c *PeerRestrictMgr) updateMapsUnsafe(ctx context.Context, + cfg peerRestrictCfg) error { + lnd := cfg.GetLndClient() - chans, err := lnd.ListChannels(context.Background(), false, false) + chans, err := lnd.ListChannels(ctx, false, false) if err != nil { return err } @@ -152,8 +154,8 @@ func (c *PeerRestrictMgr) updateMapsUnsafe(cfg peerRestrictCfg) error { return nil } -func (c *PeerRestrictMgr) getPeerFromChanPoint(cfg peerRestrictCfg, - cp string) (string, bool, error) { +func (c *PeerRestrictMgr) getPeerFromChanPoint(ctx context.Context, + cfg peerRestrictCfg, cp string) (string, bool, error) { c.mu.Lock() defer c.mu.Unlock() @@ -163,7 +165,7 @@ func (c *PeerRestrictMgr) getPeerFromChanPoint(cfg peerRestrictCfg, return peer, ok, nil } - err := c.updateMapsUnsafe(cfg) + err := c.updateMapsUnsafe(ctx, cfg) if err != nil { return "", false, err } @@ -295,7 +297,7 @@ func (c *PeerRestrictEnforcer) checkers() map[string]mid.RoundTripChecker { point := fmt.Sprintf("%s:%d", txid, index) peerID, ok, err := c.mgr.getPeerFromChanPoint( - c.cfg, point, + ctx, c.cfg, point, ) if err != nil { return err diff --git a/rules/peer_restrictions_test.go b/rules/peer_restrictions_test.go index cb9502e35..faa3c18d3 100644 --- a/rules/peer_restrictions_test.go +++ b/rules/peer_restrictions_test.go @@ -68,7 +68,7 @@ func TestPeerRestrictCheckRequest(t *testing.T) { }, } - enf, err := mgr.NewEnforcer(cfg, &PeerRestrict{ + enf, err := mgr.NewEnforcer(ctx, cfg, &PeerRestrict{ DenyList: []string{ peerID1, peerID2, }, diff --git a/rules/rate_limit.go b/rules/rate_limit.go index 8f776dcc7..4bff4bbe0 100644 --- a/rules/rate_limit.go +++ b/rules/rate_limit.go @@ -38,8 +38,8 @@ func (r *RateLimitMgr) Stop() error { // and config. // // NOTE: This is part of the Manager interface. -func (r *RateLimitMgr) NewEnforcer(cfg Config, values Values) (Enforcer, - error) { +func (r *RateLimitMgr) NewEnforcer(_ context.Context, cfg Config, + values Values) (Enforcer, error) { limits, ok := values.(*RateLimit) if !ok { From 83a3209a32a6b3702d82853e59a842f6f618b2f5 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 13 Jan 2025 07:04:36 +0200 Subject: [PATCH 5/8] rpcmiddleware: thread context through to Start --- rpcmiddleware/manager.go | 4 ++-- terminal.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rpcmiddleware/manager.go b/rpcmiddleware/manager.go index 6a43c1733..cdee7eb2d 100644 --- a/rpcmiddleware/manager.go +++ b/rpcmiddleware/manager.go @@ -36,8 +36,8 @@ func NewManager(interceptTimeout time.Duration, } // Start starts the firewall by registering the interceptors with lnd. -func (f *Manager) Start() error { - ctxc, cancel := context.WithCancel(context.Background()) +func (f *Manager) Start(ctx context.Context) error { + ctxc, cancel := context.WithCancel(ctx) f.cancel = cancel for _, i := range f.interceptors { diff --git a/terminal.go b/terminal.go index 2a9cbb316..1006f0f33 100644 --- a/terminal.go +++ b/terminal.go @@ -1104,7 +1104,7 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, g.lndClient.Client, g.errQueue.ChanIn(), mw..., ) - if err = g.middleware.Start(); err != nil { + if err = g.middleware.Start(ctx); err != nil { return err } g.middlewareStarted = true From 1a30bdb83ad83ddff8da556372f5d5fd6f24fd68 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 13 Jan 2025 07:09:04 +0200 Subject: [PATCH 6/8] accounts: thread context through to accounts --- accounts/checkers_test.go | 32 ++++++++++++++++---------------- accounts/service.go | 19 ++++++++++--------- accounts/service_test.go | 5 ++++- terminal.go | 2 +- 4 files changed, 31 insertions(+), 27 deletions(-) diff --git a/accounts/checkers_test.go b/accounts/checkers_test.go index 19e3070e3..964142380 100644 --- a/accounts/checkers_test.go +++ b/accounts/checkers_test.go @@ -499,8 +499,8 @@ func TestSendPaymentCalls(t *testing.T) { func testSendPayment(t *testing.T, uri string) { var ( - parentCtx = context.Background() - zeroFee = &lnrpc.FeeLimit{Limit: &lnrpc.FeeLimit_Fixed{ + ctx = context.Background() + zeroFee = &lnrpc.FeeLimit{Limit: &lnrpc.FeeLimit_Fixed{ Fixed: 0, }} requestID uint64 @@ -520,7 +520,7 @@ func testSendPayment(t *testing.T, uri string) { service, err := NewService(t.TempDir(), errFunc) require.NoError(t, err) - err = service.Start(lndMock, routerMock, chainParams) + err = service.Start(ctx, lndMock, routerMock, chainParams) require.NoError(t, err) assertBalance := func(id AccountID, expectedBalance int64) { @@ -533,7 +533,7 @@ func testSendPayment(t *testing.T, uri string) { // This should error because there is no account in the context. err = service.checkers.checkIncomingRequest( - parentCtx, uri, &lnrpc.SendRequest{}, + ctx, uri, &lnrpc.SendRequest{}, ) require.ErrorContains(t, err, "no account found in context") @@ -543,7 +543,7 @@ func testSendPayment(t *testing.T, uri string) { ) require.NoError(t, err) - ctxWithAcct := AddAccountToContext(parentCtx, acct) + ctxWithAcct := AddAccountToContext(ctx, acct) // This should error because there is no request ID in the context. err = service.checkers.checkIncomingRequest( @@ -552,7 +552,7 @@ func testSendPayment(t *testing.T, uri string) { require.ErrorContains(t, err, "no request ID found in context") reqID1 := nextRequestID() - ctx := AddRequestIDToContext(ctxWithAcct, reqID1) + ctx = AddRequestIDToContext(ctxWithAcct, reqID1) // This should error because no payment hash is provided. err = service.checkers.checkIncomingRequest( @@ -698,7 +698,7 @@ func testSendPayment(t *testing.T, uri string) { func TestSendPaymentV2(t *testing.T) { var ( uri = "/routerrpc.Router/SendPaymentV2" - parentCtx = context.Background() + ctx = context.Background() requestID uint64 ) @@ -716,7 +716,7 @@ func TestSendPaymentV2(t *testing.T) { service, err := NewService(t.TempDir(), errFunc) require.NoError(t, err) - err = service.Start(lndMock, routerMock, chainParams) + err = service.Start(ctx, lndMock, routerMock, chainParams) require.NoError(t, err) assertBalance := func(id AccountID, expectedBalance int64) { @@ -729,7 +729,7 @@ func TestSendPaymentV2(t *testing.T) { // This should error because there is no account in the context. err = service.checkers.checkIncomingRequest( - parentCtx, uri, &routerrpc.SendPaymentRequest{}, + ctx, uri, &routerrpc.SendPaymentRequest{}, ) require.ErrorContains(t, err, "no account found in context") @@ -739,7 +739,7 @@ func TestSendPaymentV2(t *testing.T) { ) require.NoError(t, err) - ctxWithAcct := AddAccountToContext(parentCtx, acct) + ctxWithAcct := AddAccountToContext(ctx, acct) // This should error because there is no request ID in the context. err = service.checkers.checkIncomingRequest( @@ -748,7 +748,7 @@ func TestSendPaymentV2(t *testing.T) { require.ErrorContains(t, err, "no request ID found in context") reqID1 := nextRequestID() - ctx := AddRequestIDToContext(ctxWithAcct, reqID1) + ctx = AddRequestIDToContext(ctxWithAcct, reqID1) // This should error because no payment hash is provided. err = service.checkers.checkIncomingRequest( @@ -885,7 +885,7 @@ func TestSendPaymentV2(t *testing.T) { func TestSendToRouteV2(t *testing.T) { var ( uri = "/routerrpc.Router/SendToRouteV2" - parentCtx = context.Background() + ctx = context.Background() requestID uint64 ) @@ -903,7 +903,7 @@ func TestSendToRouteV2(t *testing.T) { service, err := NewService(t.TempDir(), errFunc) require.NoError(t, err) - err = service.Start(lndMock, routerMock, chainParams) + err = service.Start(ctx, lndMock, routerMock, chainParams) require.NoError(t, err) assertBalance := func(id AccountID, expectedBalance int64) { @@ -916,7 +916,7 @@ func TestSendToRouteV2(t *testing.T) { // This should error because there is no account in the context. err = service.checkers.checkIncomingRequest( - parentCtx, uri, &routerrpc.SendToRouteRequest{}, + ctx, uri, &routerrpc.SendToRouteRequest{}, ) require.ErrorContains(t, err, "no account found in context") @@ -926,7 +926,7 @@ func TestSendToRouteV2(t *testing.T) { ) require.NoError(t, err) - ctxWithAcct := AddAccountToContext(parentCtx, acct) + ctxWithAcct := AddAccountToContext(ctx, acct) // This should error because there is no request ID in the context. err = service.checkers.checkIncomingRequest( @@ -935,7 +935,7 @@ func TestSendToRouteV2(t *testing.T) { require.ErrorContains(t, err, "no request ID found in context") reqID1 := nextRequestID() - ctx := AddRequestIDToContext(ctxWithAcct, reqID1) + ctx = AddRequestIDToContext(ctxWithAcct, reqID1) // This should error because no payment hash is provided. err = service.checkers.checkIncomingRequest( diff --git a/accounts/service.go b/accounts/service.go index 14f49decb..5db9ae872 100644 --- a/accounts/service.go +++ b/accounts/service.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/lightninglabs/lndclient" + "github.com/lightninglabs/taproot-assets/fn" "github.com/lightningnetwork/lnd/channeldb" invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnrpc" @@ -55,7 +56,7 @@ type InterceptorService struct { routerClient lndclient.RouterClient mainCtx context.Context - contextCancel context.CancelFunc + contextCancel fn.Option[context.CancelFunc] requestMtx sync.Mutex checkers *AccountChecker @@ -85,12 +86,8 @@ func NewService(dir string, return nil, err } - mainCtx, contextCancel := context.WithCancel(context.Background()) - return &InterceptorService{ store: accountStore, - mainCtx: mainCtx, - contextCancel: contextCancel, invoiceToAccount: make(map[lntypes.Hash]AccountID), pendingPayments: make(map[lntypes.Hash]*trackedPayment), requestValuesStore: newRequestValuesStore(), @@ -101,9 +98,14 @@ func NewService(dir string, } // Start starts the account service and its interceptor capability. -func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, +func (s *InterceptorService) Start(ctx context.Context, + lightningClient lndclient.LightningClient, routerClient lndclient.RouterClient, params *chaincfg.Params) error { + mainCtx, contextCancel := context.WithCancel(ctx) + s.mainCtx = mainCtx + s.contextCancel = fn.Some(contextCancel) + s.routerClient = routerClient s.checkers = NewAccountChecker(s, params) @@ -180,7 +182,7 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, s.wg.Add(1) go func() { defer s.wg.Done() - defer s.contextCancel() + defer contextCancel() for { select { @@ -235,9 +237,8 @@ func (s *InterceptorService) Stop() error { s.requestMtx.Lock() defer s.requestMtx.Unlock() - s.contextCancel() + s.contextCancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(s.quit) - s.wg.Wait() return s.store.Close() diff --git a/accounts/service_test.go b/accounts/service_test.go index 2583f612e..b38b119a4 100644 --- a/accounts/service_test.go +++ b/accounts/service_test.go @@ -838,7 +838,10 @@ func TestAccountService(t *testing.T) { } // Any errors during startup expected? - err = service.Start(lndMock, routerMock, chainParams) + err = service.Start( + context.Background(), lndMock, routerMock, + chainParams, + ) if tc.startupErr != "" { require.ErrorContains(tt, err, tc.startupErr) diff --git a/terminal.go b/terminal.go index 1006f0f33..d6282bb97 100644 --- a/terminal.go +++ b/terminal.go @@ -1042,7 +1042,7 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, log.Infof("Starting LiT account service") if !g.cfg.Accounts.Disable { err = g.accountService.Start( - g.lndClient.Client, g.lndClient.Router, + ctx, g.lndClient.Client, g.lndClient.Router, g.lndClient.ChainParams, ) if err != nil { From 0407505e6c888367a33b12171362755479a590f6 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 13 Jan 2025 07:12:38 +0200 Subject: [PATCH 7/8] autopilotserver: thread context through --- autopilotserver/client.go | 24 ++++++++++++++---------- autopilotserver/client_test.go | 2 +- autopilotserver/interface.go | 2 +- terminal.go | 3 ++- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/autopilotserver/client.go b/autopilotserver/client.go index 8abf1cce6..cfe56d97b 100644 --- a/autopilotserver/client.go +++ b/autopilotserver/client.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/lightninglabs/lightning-terminal/autopilotserverrpc" + "github.com/lightninglabs/taproot-assets/fn" "github.com/lightningnetwork/lnd/tor" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -88,8 +89,9 @@ type Client struct { featurePerms *featurePerms - quit chan struct{} - wg sync.WaitGroup + quit chan struct{} + wg sync.WaitGroup + cancel fn.Option[context.CancelFunc] } type session struct { @@ -124,16 +126,19 @@ func NewClient(cfg *Config) (Autopilot, error) { } // Start kicks off all the goroutines required by the Client. -func (c *Client) Start(opts ...func(cfg *Config)) error { +func (c *Client) Start(ctx context.Context, opts ...func(cfg *Config)) error { var startErr error c.start.Do(func() { log.Infof("Starting Autopilot Client") + ctx, cancel := context.WithCancel(ctx) + c.cancel = fn.Some(cancel) + for _, o := range opts { o(c.cfg) } - version, err := c.getMinVersion(context.Background()) + version, err := c.getMinVersion(ctx) if err != nil { startErr = err return @@ -154,8 +159,8 @@ func (c *Client) Start(opts ...func(cfg *Config)) error { } c.wg.Add(2) - go c.activateSessionsForever() - go c.updateFeaturePermsForever() + go c.activateSessionsForever(ctx) + go c.updateFeaturePermsForever(ctx) }) return startErr @@ -164,6 +169,7 @@ func (c *Client) Start(opts ...func(cfg *Config)) error { // Stop cleans up any resources or goroutines managed by the Client. func (c *Client) Stop() { c.stop.Do(func() { + c.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(c.quit) c.wg.Wait() }) @@ -222,10 +228,9 @@ func (c *Client) SessionRevoked(ctx context.Context, pubKey *btcec.PublicKey) { // activateSessionsForever periodically ensures that each of our active // autopilot sessions are known by the autopilot to be active. -func (c *Client) activateSessionsForever() { +func (c *Client) activateSessionsForever(ctx context.Context) { defer c.wg.Done() - ctx := context.Background() ticker := time.NewTicker(c.cfg.PingCadence) defer ticker.Stop() @@ -273,10 +278,9 @@ func (c *Client) activateSessionsForever() { // feature permissions list. // // NOTE: this MUST be called in a goroutine. -func (c *Client) updateFeaturePermsForever() { +func (c *Client) updateFeaturePermsForever(ctx context.Context) { defer c.wg.Done() - ctx := context.Background() ticker := time.NewTicker(time.Second) defer ticker.Stop() diff --git a/autopilotserver/client_test.go b/autopilotserver/client_test.go index 3dcb07727..d46390079 100644 --- a/autopilotserver/client_test.go +++ b/autopilotserver/client_test.go @@ -32,7 +32,7 @@ func TestAutopilotClient(t *testing.T) { PingCadence: time.Second, }) require.NoError(t, err) - require.NoError(t, client.Start()) + require.NoError(t, client.Start(ctx)) t.Cleanup(client.Stop) privKey, err := btcec.NewPrivateKey() diff --git a/autopilotserver/interface.go b/autopilotserver/interface.go index 5a98d8ad9..914747ba0 100644 --- a/autopilotserver/interface.go +++ b/autopilotserver/interface.go @@ -45,7 +45,7 @@ type Autopilot interface { SessionRevoked(ctx context.Context, key *btcec.PublicKey) // Start kicks off the goroutines of the client. - Start(opts ...func(cfg *Config)) error + Start(ctx context.Context, opts ...func(cfg *Config)) error // Stop cleans up any resources held by the client. Stop() diff --git a/terminal.go b/terminal.go index d6282bb97..c92827ad4 100644 --- a/terminal.go +++ b/terminal.go @@ -1005,7 +1005,8 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, } } - if err = g.autopilotClient.Start(withLndVersion); err != nil { + err = g.autopilotClient.Start(ctx, withLndVersion) + if err != nil { return fmt.Errorf("could not start the autopilot "+ "client: %v", err) } From 76250ca835fba4d6f2d45c4fb293dd2428408973 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 13 Jan 2025 09:07:30 +0200 Subject: [PATCH 8/8] litcli: use cancellable contexts In this commit, we use the signal.Interceptor to cancel the contexts we use for our CLI calls. --- cmd/litcli/accounts.go | 71 +++++++++++---------- cmd/litcli/actions.go | 37 ++++++----- cmd/litcli/autopilot.go | 61 +++++++++--------- cmd/litcli/ln.go | 128 ++++++++++++++++++++------------------ cmd/litcli/main.go | 28 +++++++-- cmd/litcli/privacy_map.go | 41 ++++++------ cmd/litcli/proxy.go | 35 +++++------ cmd/litcli/sessions.go | 41 ++++++------ cmd/litcli/status.go | 12 ++-- 9 files changed, 236 insertions(+), 218 deletions(-) diff --git a/cmd/litcli/accounts.go b/cmd/litcli/accounts.go index ab36a0cd4..3e5468af3 100644 --- a/cmd/litcli/accounts.go +++ b/cmd/litcli/accounts.go @@ -1,7 +1,6 @@ package main import ( - "context" "encoding/hex" "fmt" "os" @@ -75,9 +74,9 @@ spend that amount.`, Action: createAccount, } -func createAccount(ctx *cli.Context) error { - ctxb := context.Background() - clientConn, cleanup, err := connectClient(ctx, false) +func createAccount(cli *cli.Context) error { + ctx := getContext() + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } @@ -88,11 +87,11 @@ func createAccount(ctx *cli.Context) error { initialBalance uint64 expirationDate int64 ) - args := ctx.Args() + args := cli.Args() switch { - case ctx.IsSet("balance"): - initialBalance = ctx.Uint64("balance") + case cli.IsSet("balance"): + initialBalance = cli.Uint64("balance") case args.Present(): initialBalance, err = strconv.ParseUint(args.First(), 10, 64) if err != nil { @@ -102,8 +101,8 @@ func createAccount(ctx *cli.Context) error { } switch { - case ctx.IsSet("expiration_date"): - expirationDate = ctx.Int64("expiration_date") + case cli.IsSet("expiration_date"): + expirationDate = cli.Int64("expiration_date") case args.Present(): expirationDate, err = strconv.ParseInt(args.First(), 10, 64) if err != nil { @@ -117,9 +116,9 @@ func createAccount(ctx *cli.Context) error { req := &litrpc.CreateAccountRequest{ AccountBalance: initialBalance, ExpirationDate: expirationDate, - Label: ctx.String(labelName), + Label: cli.String(labelName), } - resp, err := client.CreateAccount(ctxb, req) + resp, err := client.CreateAccount(ctx, req) if err != nil { return err } @@ -128,8 +127,8 @@ func createAccount(ctx *cli.Context) error { // User requested to store the newly baked account macaroon to a file // in addition to printing it to the console. - if ctx.IsSet("save_to") { - fileName := lncfg.CleanAndExpandPath(ctx.String("save_to")) + if cli.IsSet("save_to") { + fileName := lncfg.CleanAndExpandPath(cli.String("save_to")) err := os.WriteFile(fileName, resp.Macaroon, 0644) if err != nil { return fmt.Errorf("error writing account macaroon "+ @@ -176,16 +175,16 @@ var updateAccountCommand = cli.Command{ Action: updateAccount, } -func updateAccount(ctx *cli.Context) error { - ctxb := context.Background() - clientConn, cleanup, err := connectClient(ctx, false) +func updateAccount(cli *cli.Context) error { + ctx := getContext() + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } defer cleanup() client := litrpc.NewAccountsClient(clientConn) - id, label, args, err := parseIDOrLabel(ctx) + id, label, args, err := parseIDOrLabel(cli) if err != nil { return err } @@ -195,8 +194,8 @@ func updateAccount(ctx *cli.Context) error { expirationDate int64 ) switch { - case ctx.IsSet("new_balance"): - newBalance = ctx.Int64("new_balance") + case cli.IsSet("new_balance"): + newBalance = cli.Int64("new_balance") case args.Present(): newBalance, err = strconv.ParseInt(args.First(), 10, 64) if err != nil { @@ -206,8 +205,8 @@ func updateAccount(ctx *cli.Context) error { } switch { - case ctx.IsSet("new_expiration_date"): - expirationDate = ctx.Int64("new_expiration_date") + case cli.IsSet("new_expiration_date"): + expirationDate = cli.Int64("new_expiration_date") case args.Present(): expirationDate, err = strconv.ParseInt(args.First(), 10, 64) if err != nil { @@ -224,7 +223,7 @@ func updateAccount(ctx *cli.Context) error { AccountBalance: newBalance, ExpirationDate: expirationDate, } - resp, err := client.UpdateAccount(ctxb, req) + resp, err := client.UpdateAccount(ctx, req) if err != nil { return err } @@ -242,9 +241,9 @@ var listAccountsCommand = cli.Command{ Action: listAccounts, } -func listAccounts(ctx *cli.Context) error { - ctxb := context.Background() - clientConn, cleanup, err := connectClient(ctx, false) +func listAccounts(cli *cli.Context) error { + ctx := getContext() + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } @@ -252,7 +251,7 @@ func listAccounts(ctx *cli.Context) error { client := litrpc.NewAccountsClient(clientConn) req := &litrpc.ListAccountsRequest{} - resp, err := client.ListAccounts(ctxb, req) + resp, err := client.ListAccounts(ctx, req) if err != nil { return err } @@ -281,16 +280,16 @@ var accountInfoCommand = cli.Command{ Action: accountInfo, } -func accountInfo(ctx *cli.Context) error { - ctxb := context.Background() - clientConn, cleanup, err := connectClient(ctx, false) +func accountInfo(cli *cli.Context) error { + ctx := getContext() + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } defer cleanup() client := litrpc.NewAccountsClient(clientConn) - id, label, _, err := parseIDOrLabel(ctx) + id, label, _, err := parseIDOrLabel(cli) if err != nil { return err } @@ -299,7 +298,7 @@ func accountInfo(ctx *cli.Context) error { Id: id, Label: label, } - resp, err := client.AccountInfo(ctxb, req) + resp, err := client.AccountInfo(ctx, req) if err != nil { return err } @@ -327,16 +326,16 @@ var removeAccountCommand = cli.Command{ Action: removeAccount, } -func removeAccount(ctx *cli.Context) error { - ctxb := context.Background() - clientConn, cleanup, err := connectClient(ctx, false) +func removeAccount(cli *cli.Context) error { + ctx := getContext() + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } defer cleanup() client := litrpc.NewAccountsClient(clientConn) - id, label, _, err := parseIDOrLabel(ctx) + id, label, _, err := parseIDOrLabel(cli) if err != nil { return err } @@ -345,7 +344,7 @@ func removeAccount(ctx *cli.Context) error { Id: id, Label: label, } - _, err = client.RemoveAccount(ctxb, req) + _, err = client.RemoveAccount(ctx, req) return err } diff --git a/cmd/litcli/actions.go b/cmd/litcli/actions.go index 8e757acd5..2d37e7e52 100644 --- a/cmd/litcli/actions.go +++ b/cmd/litcli/actions.go @@ -1,7 +1,6 @@ package main import ( - "context" "encoding/hex" "fmt" @@ -98,49 +97,49 @@ var listActionsCommand = cli.Command{ }, } -func listActions(ctx *cli.Context) error { - ctxb := context.Background() - clientConn, cleanup, err := connectClient(ctx, false) +func listActions(cli *cli.Context) error { + ctx := getContext() + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } defer cleanup() client := litrpc.NewFirewallClient(clientConn) - state, err := parseActionState(ctx.String("state")) + state, err := parseActionState(cli.String("state")) if err != nil { return err } var sessionID []byte - if ctx.String("session_id") != "" { - sessionID, err = hex.DecodeString(ctx.String("session_id")) + if cli.String("session_id") != "" { + sessionID, err = hex.DecodeString(cli.String("session_id")) if err != nil { return err } } var groupID []byte - if ctx.String("group_id") != "" { - groupID, err = hex.DecodeString(ctx.String("group_id")) + if cli.String("group_id") != "" { + groupID, err = hex.DecodeString(cli.String("group_id")) if err != nil { return err } } resp, err := client.ListActions( - ctxb, &litrpc.ListActionsRequest{ + ctx, &litrpc.ListActionsRequest{ SessionId: sessionID, - FeatureName: ctx.String("feature"), - ActorName: ctx.String("actor"), - MethodName: ctx.String("method"), + FeatureName: cli.String("feature"), + ActorName: cli.String("actor"), + MethodName: cli.String("method"), State: state, - IndexOffset: ctx.Uint64("index_offset"), - MaxNumActions: ctx.Uint64("max_num_actions"), - Reversed: !ctx.Bool("oldest_first"), - CountTotal: ctx.Bool("count_total"), - StartTimestamp: ctx.Uint64("start_timestamp"), - EndTimestamp: ctx.Uint64("end_timestamp"), + IndexOffset: cli.Uint64("index_offset"), + MaxNumActions: cli.Uint64("max_num_actions"), + Reversed: !cli.Bool("oldest_first"), + CountTotal: cli.Bool("count_total"), + StartTimestamp: cli.Uint64("start_timestamp"), + EndTimestamp: cli.Uint64("end_timestamp"), GroupId: groupID, }, ) diff --git a/cmd/litcli/autopilot.go b/cmd/litcli/autopilot.go index 025ea2a9b..1c5b552bf 100644 --- a/cmd/litcli/autopilot.go +++ b/cmd/litcli/autopilot.go @@ -1,7 +1,6 @@ package main import ( - "context" "encoding/hex" "encoding/json" "fmt" @@ -180,22 +179,22 @@ var listAutopilotSessionsCmd = cli.Command{ Action: listAutopilotSessions, } -func revokeAutopilotSession(ctx *cli.Context) error { - ctxb := context.Background() - clientConn, cleanup, err := connectClient(ctx, false) +func revokeAutopilotSession(cli *cli.Context) error { + ctx := getContext() + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } defer cleanup() client := litrpc.NewAutopilotClient(clientConn) - pubkey, err := hex.DecodeString(ctx.String("localpubkey")) + pubkey, err := hex.DecodeString(cli.String("localpubkey")) if err != nil { return err } resp, err := client.RevokeAutopilotSession( - ctxb, &litrpc.RevokeAutopilotSessionRequest{ + ctx, &litrpc.RevokeAutopilotSessionRequest{ LocalPublicKey: pubkey, }, ) @@ -208,9 +207,9 @@ func revokeAutopilotSession(ctx *cli.Context) error { return nil } -func listAutopilotSessions(ctx *cli.Context) error { - ctxb := context.Background() - clientConn, cleanup, err := connectClient(ctx, false) +func listAutopilotSessions(cli *cli.Context) error { + ctx := getContext() + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } @@ -218,7 +217,7 @@ func listAutopilotSessions(ctx *cli.Context) error { client := litrpc.NewAutopilotClient(clientConn) resp, err := client.ListAutopilotSessions( - ctxb, &litrpc.ListAutopilotSessionsRequest{}, + ctx, &litrpc.ListAutopilotSessionsRequest{}, ) if err != nil { return err @@ -229,9 +228,9 @@ func listAutopilotSessions(ctx *cli.Context) error { return nil } -func listFeatures(ctx *cli.Context) error { - ctxb := context.Background() - clientConn, cleanup, err := connectClient(ctx, false) +func listFeatures(cli *cli.Context) error { + ctx := getContext() + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } @@ -239,7 +238,7 @@ func listFeatures(ctx *cli.Context) error { client := litrpc.NewAutopilotClient(clientConn) resp, err := client.ListAutopilotFeatures( - ctxb, &litrpc.ListAutopilotFeaturesRequest{}, + ctx, &litrpc.ListAutopilotFeaturesRequest{}, ) if err != nil { return err @@ -250,19 +249,19 @@ func listFeatures(ctx *cli.Context) error { return nil } -func initAutopilotSession(ctx *cli.Context) error { - sessionLength := time.Second * time.Duration(ctx.Uint64("expiry")) +func initAutopilotSession(cli *cli.Context) error { + sessionLength := time.Second * time.Duration(cli.Uint64("expiry")) sessionExpiry := time.Now().Add(sessionLength).Unix() - ctxb := context.Background() - clientConn, cleanup, err := connectClient(ctx, false) + ctx := getContext() + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } defer cleanup() client := litrpc.NewAutopilotClient(clientConn) - features := ctx.StringSlice("feature") + features := cli.StringSlice("feature") // Check that the user only sets unique features. fs := make(map[string]struct{}) @@ -277,14 +276,14 @@ func initAutopilotSession(ctx *cli.Context) error { // Check that the user did not set multiple restrict lists. var chanRestrictList, peerRestrictList string - channelRestrictSlice := ctx.StringSlice("channel-restrict-list") + channelRestrictSlice := cli.StringSlice("channel-restrict-list") if len(channelRestrictSlice) > 1 { return fmt.Errorf("channel-restrict-list can only be used once") } else if len(channelRestrictSlice) == 1 { chanRestrictList = channelRestrictSlice[0] } - peerRestrictSlice := ctx.StringSlice("peer-restrict-list") + peerRestrictSlice := cli.StringSlice("peer-restrict-list") if len(peerRestrictSlice) > 1 { return fmt.Errorf("peer-restrict-list can only be used once") } else if len(peerRestrictSlice) == 1 { @@ -293,7 +292,7 @@ func initAutopilotSession(ctx *cli.Context) error { // rulesMap stores the rules per each feature. rulesMap := make(map[string]*litrpc.RulesMap) - rulesFlags := ctx.StringSlice("feature-rules") + rulesFlags := cli.StringSlice("feature-rules") // For legacy flags, we allow setting the channel and peer restrict // lists when only a single feature is added. @@ -379,7 +378,7 @@ func initAutopilotSession(ctx *cli.Context) error { } } - configs := ctx.StringSlice("feature-config") + configs := cli.StringSlice("feature-config") if len(configs) > 0 && len(features) != len(configs) { return fmt.Errorf("number of features (%v) and configurations "+ "(%v) must match", len(features), len(configs)) @@ -420,8 +419,8 @@ func initAutopilotSession(ctx *cli.Context) error { } var groupID []byte - if ctx.IsSet("group_id") { - groupID, err = hex.DecodeString(ctx.String("group_id")) + if cli.IsSet("group_id") { + groupID, err = hex.DecodeString(cli.String("group_id")) if err != nil { return err } @@ -429,10 +428,10 @@ func initAutopilotSession(ctx *cli.Context) error { var privacyFlags uint64 var privacyFlagsSet bool - if ctx.IsSet("privacy-flags") { + if cli.IsSet("privacy-flags") { privacyFlagsSet = true - flags, err := session.Parse(ctx.String("privacy-flags")) + flags, err := session.Parse(cli.String("privacy-flags")) if err != nil { return err } @@ -441,11 +440,11 @@ func initAutopilotSession(ctx *cli.Context) error { } resp, err := client.AddAutopilotSession( - ctxb, &litrpc.AddAutopilotSessionRequest{ - Label: ctx.String("label"), + ctx, &litrpc.AddAutopilotSessionRequest{ + Label: cli.String("label"), ExpiryTimestampSeconds: uint64(sessionExpiry), - MailboxServerAddr: ctx.String("mailboxserveraddr"), - DevServer: ctx.Bool("devserver"), + MailboxServerAddr: cli.String("mailboxserveraddr"), + DevServer: cli.Bool("devserver"), Features: featureMap, LinkedGroupId: groupID, PrivacyFlags: privacyFlags, diff --git a/cmd/litcli/ln.go b/cmd/litcli/ln.go index 83ab07074..e4494061d 100644 --- a/cmd/litcli/ln.go +++ b/cmd/litcli/ln.go @@ -91,17 +91,17 @@ var fundChannelCommand = cli.Command{ } func fundChannel(c *cli.Context) error { - tapdConn, cleanup, err := connectSuperMacClient(c) + ctx := getContext() + tapdConn, cleanup, err := connectSuperMacClient(ctx, c) if err != nil { return fmt.Errorf("error creating tapd connection: %w", err) } defer cleanup() - ctxb := context.Background() tapdClient := taprpc.NewTaprootAssetsClient(tapdConn) tchrpcClient := tchrpc.NewTaprootAssetChannelsClient(tapdConn) - assets, err := tapdClient.ListAssets(ctxb, &taprpc.ListAssetRequest{}) + assets, err := tapdClient.ListAssets(ctx, &taprpc.ListAssetRequest{}) if err != nil { return fmt.Errorf("error fetching assets: %w", err) } @@ -146,7 +146,7 @@ func fundChannel(c *cli.Context) error { } resp, err := tchrpcClient.FundChannel( - ctxb, &tchrpc.FundChannelRequest{ + ctx, &tchrpc.FundChannelRequest{ AssetAmount: requestedAmount, AssetId: assetIDBytes, PeerPubkey: nodePubBytes, @@ -297,41 +297,46 @@ var sendPaymentCommand = cli.Command{ Action: sendPayment, } -func sendPayment(ctx *cli.Context) error { +func sendPayment(cliCtx *cli.Context) error { // Show command help if no arguments provided - if ctx.NArg() == 0 && ctx.NumFlags() == 0 { - _ = cli.ShowCommandHelp(ctx, "sendpayment") + if cliCtx.NArg() == 0 && cliCtx.NumFlags() == 0 { + _ = cli.ShowCommandHelp(cliCtx, "sendpayment") return nil } - lndConn, cleanup, err := connectClient(ctx, false) + lndConn, cleanup, err := connectClient(cliCtx, false) if err != nil { return fmt.Errorf("unable to make rpc conn: %w", err) } defer cleanup() - tapdConn, cleanup, err := connectSuperMacClient(ctx) + // NOTE: we don't use `getContext()` here since it assigns the global + // signal interceptor variable which will then cause + // commands.SendPaymentRequest to error out since it will try to do the + // same. + ctx := context.Background() + tapdConn, cleanup, err := connectSuperMacClient(ctx, cliCtx) if err != nil { return fmt.Errorf("error creating tapd connection: %w", err) } defer cleanup() switch { - case !ctx.IsSet(assetIDFlag.Name): + case !cliCtx.IsSet(assetIDFlag.Name): return fmt.Errorf("the --asset_id flag must be set") - case !ctx.IsSet("keysend"): + case !cliCtx.IsSet("keysend"): return fmt.Errorf("the --keysend flag must be set") - case !ctx.IsSet(assetAmountFlag.Name): + case !cliCtx.IsSet(assetAmountFlag.Name): return fmt.Errorf("--asset_amount must be set") } - assetIDStr := ctx.String(assetIDFlag.Name) + assetIDStr := cliCtx.String(assetIDFlag.Name) assetIDBytes, err := hex.DecodeString(assetIDStr) if err != nil { return fmt.Errorf("unable to decode assetID: %v", err) } - assetAmountToSend := ctx.Uint64(assetAmountFlag.Name) + assetAmountToSend := cliCtx.Uint64(assetAmountFlag.Name) if assetAmountToSend == 0 { return fmt.Errorf("must specify asset amount to send") } @@ -344,8 +349,8 @@ func sendPayment(ctx *cli.Context) error { ) switch { - case ctx.IsSet("dest"): - destNode, err = hex.DecodeString(ctx.String("dest")) + case cliCtx.IsSet("dest"): + destNode, err = hex.DecodeString(cliCtx.String("dest")) default: return fmt.Errorf("destination txid argument missing") } @@ -358,7 +363,7 @@ func sendPayment(ctx *cli.Context) error { "is instead: %v", len(destNode)) } - rfqPeerKey, err := hex.DecodeString(ctx.String(rfqPeerPubKeyFlag.Name)) + rfqPeerKey, err := hex.DecodeString(cliCtx.String(rfqPeerPubKeyFlag.Name)) if err != nil { return fmt.Errorf("unable to decode RFQ peer public key: "+ "%w", err) @@ -373,7 +378,7 @@ func sendPayment(ctx *cli.Context) error { DestCustomRecords: make(map[uint64][]byte), } - if ctx.IsSet("payment_hash") { + if cliCtx.IsSet("payment_hash") { return errors.New("cannot set payment hash when using " + "keysend") } @@ -392,10 +397,10 @@ func sendPayment(ctx *cli.Context) error { rHash = hash[:] req.PaymentHash = rHash - allowOverpay := ctx.Bool(allowOverpayFlag.Name) + allowOverpay := cliCtx.Bool(allowOverpayFlag.Name) return commands.SendPaymentRequest( - ctx, req, lndConn, tapdConn, func(ctx context.Context, + cliCtx, req, lndConn, tapdConn, func(ctx context.Context, payConn grpc.ClientConnInterface, req *routerrpc.SendPaymentRequest) ( commands.PaymentResultStream, error) { @@ -447,21 +452,26 @@ var payInvoiceCommand = cli.Command{ Action: payInvoice, } -func payInvoice(ctx *cli.Context) error { - args := ctx.Args() - ctxb := context.Background() +func payInvoice(cli *cli.Context) error { + args := cli.Args() + + // NOTE: we don't use `getContext()` here since it assigns the global + // signal interceptor variable which will then cause + // commands.SendPaymentRequest to error out since it will try to do the + // same. + ctx := context.Background() var payReq string switch { - case ctx.IsSet("pay_req"): - payReq = ctx.String("pay_req") + case cli.IsSet("pay_req"): + payReq = cli.String("pay_req") case args.Present(): payReq = args.First() default: return fmt.Errorf("pay_req argument missing") } - superMacConn, cleanup, err := connectSuperMacClient(ctx) + superMacConn, cleanup, err := connectSuperMacClient(ctx, cli) if err != nil { return fmt.Errorf("unable to make rpc con: %w", err) } @@ -471,35 +481,35 @@ func payInvoice(ctx *cli.Context) error { lndClient := lnrpc.NewLightningClient(superMacConn) decodeReq := &lnrpc.PayReqString{PayReq: payReq} - decodeResp, err := lndClient.DecodePayReq(ctxb, decodeReq) + decodeResp, err := lndClient.DecodePayReq(ctx, decodeReq) if err != nil { return err } - if !ctx.IsSet(assetIDFlag.Name) { + if !cli.IsSet(assetIDFlag.Name) { return fmt.Errorf("the --asset_id flag must be set") } - assetIDStr := ctx.String(assetIDFlag.Name) + assetIDStr := cli.String(assetIDFlag.Name) assetIDBytes, err := hex.DecodeString(assetIDStr) if err != nil { return fmt.Errorf("unable to decode assetID: %v", err) } - rfqPeerKey, err := hex.DecodeString(ctx.String(rfqPeerPubKeyFlag.Name)) + rfqPeerKey, err := hex.DecodeString(cli.String(rfqPeerPubKeyFlag.Name)) if err != nil { return fmt.Errorf("unable to decode RFQ peer public key: "+ "%w", err) } - allowOverpay := ctx.Bool(allowOverpayFlag.Name) + allowOverpay := cli.Bool(allowOverpayFlag.Name) req := &routerrpc.SendPaymentRequest{ PaymentRequest: commands.StripPrefix(payReq), } return commands.SendPaymentRequest( - ctx, req, superMacConn, superMacConn, func(ctx context.Context, + cli, req, superMacConn, superMacConn, func(ctx context.Context, payConn grpc.ClientConnInterface, req *routerrpc.SendPaymentRequest) ( commands.PaymentResultStream, error) { @@ -559,14 +569,14 @@ var addInvoiceCommand = cli.Command{ Action: addInvoice, } -func addInvoice(ctx *cli.Context) error { - args := ctx.Args() - ctxb := context.Background() +func addInvoice(cli *cli.Context) error { + args := cli.Args() + ctx := getContext() var assetIDStr string switch { - case ctx.IsSet("asset_id"): - assetIDStr = ctx.String("asset_id") + case cli.IsSet("asset_id"): + assetIDStr = cli.String("asset_id") case args.Present(): assetIDStr = args.First() args = args.Tail() @@ -581,8 +591,8 @@ func addInvoice(ctx *cli.Context) error { err error ) switch { - case ctx.IsSet("asset_amount"): - assetAmount = ctx.Uint64("asset_amount") + case cli.IsSet("asset_amount"): + assetAmount = cli.Uint64("asset_amount") case args.Present(): assetAmount, err = strconv.ParseUint(args.First(), 10, 64) if err != nil { @@ -593,21 +603,21 @@ func addInvoice(ctx *cli.Context) error { return fmt.Errorf("asset_amount argument missing") } - if ctx.IsSet("preimage") { - preimage, err = hex.DecodeString(ctx.String("preimage")) + if cli.IsSet("preimage") { + preimage, err = hex.DecodeString(cli.String("preimage")) if err != nil { return fmt.Errorf("unable to parse preimage: %w", err) } } - descHash, err = hex.DecodeString(ctx.String("description_hash")) + descHash, err = hex.DecodeString(cli.String("description_hash")) if err != nil { return fmt.Errorf("unable to parse description_hash: %w", err) } expirySeconds := int64(rfq.DefaultInvoiceExpiry.Seconds()) - if ctx.IsSet("expiry") { - expirySeconds = ctx.Int64("expiry") + if cli.IsSet("expiry") { + expirySeconds = cli.Int64("expiry") } assetIDBytes, err := hex.DecodeString(assetIDStr) @@ -618,31 +628,31 @@ func addInvoice(ctx *cli.Context) error { var assetID asset.ID copy(assetID[:], assetIDBytes) - rfqPeerKey, err := hex.DecodeString(ctx.String(rfqPeerPubKeyFlag.Name)) + rfqPeerKey, err := hex.DecodeString(cli.String(rfqPeerPubKeyFlag.Name)) if err != nil { return fmt.Errorf("unable to decode RFQ peer public key: "+ "%w", err) } - tapdConn, cleanup, err := connectSuperMacClient(ctx) + tapdConn, cleanup, err := connectSuperMacClient(ctx, cli) if err != nil { return fmt.Errorf("error creating tapd connection: %w", err) } defer cleanup() channelsClient := tchrpc.NewTaprootAssetChannelsClient(tapdConn) - resp, err := channelsClient.AddInvoice(ctxb, &tchrpc.AddInvoiceRequest{ + resp, err := channelsClient.AddInvoice(ctx, &tchrpc.AddInvoiceRequest{ AssetId: assetIDBytes, AssetAmount: assetAmount, PeerPubkey: rfqPeerKey, InvoiceRequest: &lnrpc.Invoice{ - Memo: ctx.String("memo"), + Memo: cli.String("memo"), RPreimage: preimage, DescriptionHash: descHash, - FallbackAddr: ctx.String("fallback_addr"), + FallbackAddr: cli.String("fallback_addr"), Expiry: expirySeconds, - Private: ctx.Bool("private"), - IsAmp: ctx.Bool("amp"), + Private: cli.Bool("private"), + IsAmp: cli.Bool("amp"), }, }) if err != nil { @@ -679,32 +689,32 @@ var decodeAssetInvoiceCommand = cli.Command{ Action: decodeAssetInvoice, } -func decodeAssetInvoice(ctx *cli.Context) error { - ctxb := context.Background() +func decodeAssetInvoice(cli *cli.Context) error { + ctx := getContext() switch { - case !ctx.IsSet("pay_req"): + case !cli.IsSet("pay_req"): return fmt.Errorf("pay_req argument missing") - case !ctx.IsSet(assetIDFlag.Name): + case !cli.IsSet(assetIDFlag.Name): return fmt.Errorf("the --asset_id flag must be set") } - payReq := ctx.String("pay_req") + payReq := cli.String("pay_req") - assetIDStr := ctx.String(assetIDFlag.Name) + assetIDStr := cli.String(assetIDFlag.Name) assetIDBytes, err := hex.DecodeString(assetIDStr) if err != nil { return fmt.Errorf("unable to decode assetID: %v", err) } - tapdConn, cleanup, err := connectSuperMacClient(ctx) + tapdConn, cleanup, err := connectSuperMacClient(ctx, cli) if err != nil { return fmt.Errorf("unable to make rpc con: %w", err) } defer cleanup() channelsClient := tchrpc.NewTaprootAssetChannelsClient(tapdConn) - resp, err := channelsClient.DecodeAssetPayReq(ctxb, &tchrpc.AssetPayReq{ + resp, err := channelsClient.DecodeAssetPayReq(ctx, &tchrpc.AssetPayReq{ AssetId: assetIDBytes, PayReqString: payReq, }) diff --git a/cmd/litcli/main.go b/cmd/litcli/main.go index 934e2020e..6143afaf6 100644 --- a/cmd/litcli/main.go +++ b/cmd/litcli/main.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/macaroons" + "github.com/lightningnetwork/lnd/signal" "github.com/urfave/cli" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -299,19 +300,18 @@ func printRespJSON(resp proto.Message) { // nolint fmt.Println(string(jsonBytes)) } -func connectSuperMacClient(ctx *cli.Context) (grpc.ClientConnInterface, - func(), error) { +func connectSuperMacClient(ctx context.Context, cli *cli.Context) ( + grpc.ClientConnInterface, func(), error) { - litdConn, cleanup, err := connectClient(ctx, false) + litdConn, cleanup, err := connectClient(cli, false) if err != nil { return nil, nil, fmt.Errorf("error connecting client: %w", err) } defer cleanup() - ctxb := context.Background() litClient := litrpc.NewProxyClient(litdConn) macResp, err := litClient.BakeSuperMacaroon( - ctxb, &litrpc.BakeSuperMacaroonRequest{}, + ctx, &litrpc.BakeSuperMacaroonRequest{}, ) if err != nil { return nil, nil, fmt.Errorf("error baking macaroon: %w", err) @@ -322,5 +322,21 @@ func connectSuperMacClient(ctx *cli.Context) (grpc.ClientConnInterface, return nil, nil, fmt.Errorf("error decoding macaroon: %w", err) } - return connectClientWithMac(ctx, macBytes) + return connectClientWithMac(cli, macBytes) +} + +func getContext() context.Context { + shutdownInterceptor, err := signal.Intercept() + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + + ctxc, cancel := context.WithCancel(context.Background()) + go func() { + <-shutdownInterceptor.ShutdownChannel() + cancel() + }() + + return ctxc } diff --git a/cmd/litcli/privacy_map.go b/cmd/litcli/privacy_map.go index 8532d6583..8bcc7f2fe 100644 --- a/cmd/litcli/privacy_map.go +++ b/cmd/litcli/privacy_map.go @@ -1,7 +1,6 @@ package main import ( - "context" "encoding/hex" "fmt" @@ -63,9 +62,9 @@ var privacyMapConvertStrCommand = cli.Command{ }, } -func privacyMapConvertStr(ctx *cli.Context) error { - ctxb := context.Background() - clientConn, cleanup, err := connectClient(ctx, false) +func privacyMapConvertStr(cli *cli.Context) error { + ctx := getContext() + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } @@ -73,13 +72,13 @@ func privacyMapConvertStr(ctx *cli.Context) error { client := litrpc.NewFirewallClient(clientConn) var groupID []byte - if ctx.GlobalIsSet("group_id") { - groupID, err = hex.DecodeString(ctx.GlobalString("group_id")) + if cli.GlobalIsSet("group_id") { + groupID, err = hex.DecodeString(cli.GlobalString("group_id")) if err != nil { return err } - } else if ctx.GlobalIsSet("session_id") { - groupID, err = hex.DecodeString(ctx.GlobalString("session_id")) + } else if cli.GlobalIsSet("session_id") { + groupID, err = hex.DecodeString(cli.GlobalString("session_id")) if err != nil { return err } @@ -88,9 +87,9 @@ func privacyMapConvertStr(ctx *cli.Context) error { } resp, err := client.PrivacyMapConversion( - ctxb, &litrpc.PrivacyMapConversionRequest{ - RealToPseudo: ctx.GlobalBool("realtopseudo"), - Input: ctx.String("input"), + ctx, &litrpc.PrivacyMapConversionRequest{ + RealToPseudo: cli.GlobalBool("realtopseudo"), + Input: cli.String("input"), GroupId: groupID, }, ) @@ -117,9 +116,9 @@ var privacyMapConvertUint64Command = cli.Command{ }, } -func privacyMapConvertUint64(ctx *cli.Context) error { - ctxb := context.Background() - clientConn, cleanup, err := connectClient(ctx, false) +func privacyMapConvertUint64(cli *cli.Context) error { + ctx := getContext() + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } @@ -127,13 +126,13 @@ func privacyMapConvertUint64(ctx *cli.Context) error { client := litrpc.NewFirewallClient(clientConn) var groupID []byte - if ctx.GlobalIsSet("group_id") { - groupID, err = hex.DecodeString(ctx.GlobalString("group_id")) + if cli.GlobalIsSet("group_id") { + groupID, err = hex.DecodeString(cli.GlobalString("group_id")) if err != nil { return err } - } else if ctx.GlobalIsSet("session_id") { - groupID, err = hex.DecodeString(ctx.GlobalString("session_id")) + } else if cli.GlobalIsSet("session_id") { + groupID, err = hex.DecodeString(cli.GlobalString("session_id")) if err != nil { return err } @@ -141,11 +140,11 @@ func privacyMapConvertUint64(ctx *cli.Context) error { return fmt.Errorf("must set group_id") } - input := firewalldb.Uint64ToStr(ctx.Uint64("input")) + input := firewalldb.Uint64ToStr(cli.Uint64("input")) resp, err := client.PrivacyMapConversion( - ctxb, &litrpc.PrivacyMapConversionRequest{ - RealToPseudo: ctx.GlobalBool("realtopseudo"), + ctx, &litrpc.PrivacyMapConversionRequest{ + RealToPseudo: cli.GlobalBool("realtopseudo"), Input: input, GroupId: groupID, }, diff --git a/cmd/litcli/proxy.go b/cmd/litcli/proxy.go index 268de50b8..2ae302ad3 100644 --- a/cmd/litcli/proxy.go +++ b/cmd/litcli/proxy.go @@ -1,7 +1,6 @@ package main import ( - "context" "crypto/rand" "encoding/binary" "encoding/hex" @@ -62,16 +61,16 @@ var litCommands = []cli.Command{ }, } -func getInfo(ctx *cli.Context) error { - clientConn, cleanup, err := connectClient(ctx, false) +func getInfo(cli *cli.Context) error { + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } defer cleanup() client := litrpc.NewProxyClient(clientConn) - ctxb := context.Background() - resp, err := client.GetInfo(ctxb, &litrpc.GetInfoRequest{}) + ctx := getContext() + resp, err := client.GetInfo(ctx, &litrpc.GetInfoRequest{}) if err != nil { return err } @@ -81,16 +80,16 @@ func getInfo(ctx *cli.Context) error { return nil } -func shutdownLit(ctx *cli.Context) error { - clientConn, cleanup, err := connectClient(ctx, false) +func shutdownLit(cli *cli.Context) error { + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } defer cleanup() client := litrpc.NewProxyClient(clientConn) - ctxb := context.Background() - _, err = client.StopDaemon(ctxb, &litrpc.StopDaemonRequest{}) + ctx := getContext() + _, err = client.StopDaemon(ctx, &litrpc.StopDaemonRequest{}) if err != nil { return err } @@ -100,11 +99,11 @@ func shutdownLit(ctx *cli.Context) error { return nil } -func bakeSuperMacaroon(ctx *cli.Context) error { +func bakeSuperMacaroon(cli *cli.Context) error { var suffixBytes [4]byte - if ctx.IsSet("root_key_suffix") { + if cli.IsSet("root_key_suffix") { suffixHex, err := hex.DecodeString( - ctx.String("root_key_suffix"), + cli.String("root_key_suffix"), ) if err != nil { return err @@ -119,18 +118,18 @@ func bakeSuperMacaroon(ctx *cli.Context) error { } suffix := binary.BigEndian.Uint32(suffixBytes[:]) - clientConn, cleanup, err := connectClient(ctx, false) + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } defer cleanup() client := litrpc.NewProxyClient(clientConn) - ctxb := context.Background() + ctx := getContext() resp, err := client.BakeSuperMacaroon( - ctxb, &litrpc.BakeSuperMacaroonRequest{ + ctx, &litrpc.BakeSuperMacaroonRequest{ RootKeyIdSuffix: suffix, - ReadOnly: ctx.Bool("read_only"), + ReadOnly: cli.Bool("read_only"), }, ) if err != nil { @@ -139,8 +138,8 @@ func bakeSuperMacaroon(ctx *cli.Context) error { // If the user specified the optional --save_to parameter, we'll save // the macaroon to that file. - if ctx.IsSet("save_to") { - macSavePath := lncfg.CleanAndExpandPath(ctx.String("save_to")) + if cli.IsSet("save_to") { + macSavePath := lncfg.CleanAndExpandPath(cli.String("save_to")) superMacBytes, err := hex.DecodeString(resp.Macaroon) if err != nil { return err diff --git a/cmd/litcli/sessions.go b/cmd/litcli/sessions.go index c3a7d896e..c21c69c97 100644 --- a/cmd/litcli/sessions.go +++ b/cmd/litcli/sessions.go @@ -1,7 +1,6 @@ package main import ( - "context" "encoding/hex" "fmt" "time" @@ -96,41 +95,41 @@ var addSessionCommand = cli.Command{ }, } -func addSession(ctx *cli.Context) error { - clientConn, cleanup, err := connectClient(ctx, false) +func addSession(cli *cli.Context) error { + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } defer cleanup() client := litrpc.NewSessionsClient(clientConn) - sessTypeStr := ctx.String("type") + sessTypeStr := cli.String("type") sessType, err := parseSessionType(sessTypeStr) if err != nil { return err } var macPerms []*litrpc.MacaroonPermission - for _, uri := range ctx.StringSlice("uri") { + for _, uri := range cli.StringSlice("uri") { macPerms = append(macPerms, &litrpc.MacaroonPermission{ Entity: macaroons.PermissionEntityCustomURI, Action: uri, }) } - sessionLength := time.Second * time.Duration(ctx.Uint64("expiry")) + sessionLength := time.Second * time.Duration(cli.Uint64("expiry")) sessionExpiry := time.Now().Add(sessionLength).Unix() - ctxb := context.Background() + ctx := getContext() resp, err := client.AddSession( - ctxb, &litrpc.AddSessionRequest{ - Label: ctx.String("label"), + ctx, &litrpc.AddSessionRequest{ + Label: cli.String("label"), SessionType: sessType, ExpiryTimestampSeconds: uint64(sessionExpiry), - MailboxServerAddr: ctx.String("mailboxserveraddr"), - DevServer: ctx.Bool("devserver"), + MailboxServerAddr: cli.String("mailboxserveraddr"), + DevServer: cli.Bool("devserver"), MacaroonCustomPermissions: macPerms, - AccountId: ctx.String("account_id"), + AccountId: cli.String("account_id"), }, ) if err != nil { @@ -229,17 +228,17 @@ var sessionStateMap = map[litrpc.SessionState]sessionFilter{ } func listSessions(filter sessionFilter) func(ctx *cli.Context) error { - return func(ctx *cli.Context) error { - clientConn, cleanup, err := connectClient(ctx, false) + return func(cli *cli.Context) error { + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } defer cleanup() client := litrpc.NewSessionsClient(clientConn) - ctxb := context.Background() + ctx := getContext() resp, err := client.ListSessions( - ctxb, &litrpc.ListSessionsRequest{}, + ctx, &litrpc.ListSessionsRequest{}, ) if err != nil { return err @@ -279,22 +278,22 @@ var revokeSessionCommand = cli.Command{ }, } -func revokeSession(ctx *cli.Context) error { - clientConn, cleanup, err := connectClient(ctx, false) +func revokeSession(cli *cli.Context) error { + clientConn, cleanup, err := connectClient(cli, false) if err != nil { return err } defer cleanup() client := litrpc.NewSessionsClient(clientConn) - pubkey, err := hex.DecodeString(ctx.String("localpubkey")) + pubkey, err := hex.DecodeString(cli.String("localpubkey")) if err != nil { return err } - ctxb := context.Background() + ctx := getContext() resp, err := client.RevokeSession( - ctxb, &litrpc.RevokeSessionRequest{ + ctx, &litrpc.RevokeSessionRequest{ LocalPublicKey: pubkey, }, ) diff --git a/cmd/litcli/status.go b/cmd/litcli/status.go index 6100866d3..3ae0c2e40 100644 --- a/cmd/litcli/status.go +++ b/cmd/litcli/status.go @@ -1,8 +1,6 @@ package main import ( - "context" - "github.com/lightninglabs/lightning-terminal/litrpc" "github.com/lightningnetwork/lnd/lnrpc" "github.com/urfave/cli" @@ -18,8 +16,8 @@ var statusCommands = []cli.Command{ }, } -func getStatus(ctx *cli.Context) error { - clientConn, cleanup, err := connectClient(ctx, true) +func getStatus(cli *cli.Context) error { + clientConn, cleanup, err := connectClient(cli, true) if err != nil { return err } @@ -27,9 +25,9 @@ func getStatus(ctx *cli.Context) error { litClient := litrpc.NewStatusClient(clientConn) // Get LiT's status. - ctxb := context.Background() + ctx := getContext() litResp, err := litClient.SubServerStatus( - ctxb, &litrpc.SubServerStatusReq{}, + ctx, &litrpc.SubServerStatusReq{}, ) if err != nil { return err @@ -39,7 +37,7 @@ func getStatus(ctx *cli.Context) error { // Get LND's state. lndClient := lnrpc.NewStateClient(clientConn) - lndResp, err := lndClient.GetState(ctxb, &lnrpc.GetStateRequest{}) + lndResp, err := lndClient.GetState(ctx, &lnrpc.GetStateRequest{}) if err != nil { return err }