Skip to content

Commit

Permalink
make the tx channel-upgrade init --unsafe flag safer
Browse files Browse the repository at this point in the history
Signed-off-by: Masanori Yoshida <[email protected]>
  • Loading branch information
siburu committed Oct 8, 2024
1 parent 2b8ff40 commit 75c1603
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 15 deletions.
26 changes: 12 additions & 14 deletions cmd/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,16 +254,9 @@ func channelUpgradeInitCmd(ctx *config.Context) *cobra.Command {
}

// check cp state
if unsafe, err := cmd.Flags().GetBool(flagUnsafe); err != nil {
permitUnsafe, err := cmd.Flags().GetBool(flagUnsafe)
if err != nil {
return err
} else if !unsafe {
if height, err := cp.LatestHeight(); err != nil {
return err
} else if chann, err := cp.QueryChannel(core.NewQueryContext(cmd.Context(), height)); err != nil {
return err
} else if state := chann.Channel.State; state >= chantypes.FLUSHING && state <= chantypes.FLUSHCOMPLETE {
return fmt.Errorf("stop channel upgrade initialization because the counterparty is in %v state", state)
}
}

// get ordering from flags
Expand All @@ -286,11 +279,16 @@ func channelUpgradeInitCmd(ctx *config.Context) *cobra.Command {
return err
}

return core.InitChannelUpgrade(chain, chantypes.UpgradeFields{
Ordering: ordering,
ConnectionHops: connHops,
Version: version,
})
return core.InitChannelUpgrade(
chain,
cp,
chantypes.UpgradeFields{
Ordering: ordering,
ConnectionHops: connHops,
Version: version,
},
permitUnsafe,
)
},
}

Expand Down
32 changes: 31 additions & 1 deletion core/channel-upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,40 @@ func (action UpgradeAction) String() string {
}

// InitChannelUpgrade builds `MsgChannelUpgradeInit` based on the specified UpgradeFields and sends it to the specified chain.
func InitChannelUpgrade(chain *ProvableChain, upgradeFields chantypes.UpgradeFields) error {
func InitChannelUpgrade(chain, cp *ProvableChain, upgradeFields chantypes.UpgradeFields, permitUnsafe bool) error {
logger := GetChannelLogger(chain.Chain)
defer logger.TimeTrack(time.Now(), "InitChannelUpgrade")

if h, err := chain.LatestHeight(); err != nil {
logger.Error("failed to get the latest height", err)
return err
} else if cpH, err := cp.LatestHeight(); err != nil {
logger.Error("failed to get the latest height of the counterparty chain", err)
return err
} else if chann, cpChann, err := QueryChannelPair(
NewQueryContext(context.TODO(), h),
NewQueryContext(context.TODO(), cpH),
chain,
cp,
false,
); err != nil {
logger.Error("failed to query for the channel pair", err)
return err
} else if chann.Channel.State != chantypes.OPEN || cpChann.Channel.State != chantypes.OPEN {
logger = &log.RelayLogger{Logger: logger.With(
"channel_state", chann.Channel.State,
"cp_channel_state", cpChann.Channel.State,
)}

if permitUnsafe {
logger.Info("unsafe channel upgrade is permitted")
} else {
err := errors.New("unsafe channel upgrade initialization")
logger.Error("unsafe channel upgrade is not permitted", err)
return err
}
}

addr, err := chain.GetAddress()
if err != nil {
logger.Error("failed to get address", err)
Expand Down

0 comments on commit 75c1603

Please sign in to comment.