diff --git a/modules/core/04-channel/keeper/handshake.go b/modules/core/04-channel/keeper/handshake.go index e4c06243b31..28fb2fc3ef3 100644 --- a/modules/core/04-channel/keeper/handshake.go +++ b/modules/core/04-channel/keeper/handshake.go @@ -152,72 +152,55 @@ func (k Keeper) ChanOpenTry( ) } - // handle multihop case + var counterpartyHops []string if len(connectionHops) > 1 { - kvGenerator := func(mProof *types.MsgMultihopProofs, lastHopConnectionEnd *connectiontypes.ConnectionEnd) (string, []byte, error) { - // check version support - if err := checkVersion(lastHopConnectionEnd, order); err != nil { - return "", nil, err - } - - // channel end storekey of the counterparty channel on the other end of the multihop channel - key := host.ChannelPath(counterparty.PortId, counterparty.ChannelId) - - // connection hops - counterpartyHops, err := mProof.GetCounterpartyConnectionHops(k.cdc, &connectionEnd) - if err != nil { - return "", nil, err - } - - // expectedCounterparty is the counterparty of the counterparty's channel end (i.e self) - expectedCounterparty := types.NewCounterparty(portID, "") - expectedChannel := types.NewChannel(types.INIT, order, expectedCounterparty, counterpartyHops, counterpartyVersion) - - // expected value bytes - bz, err := k.cdc.Marshal(&expectedChannel) - if err != nil { - return "", nil, err - } - - return key, bz, nil + var multihopProof types.MsgMultihopProofs + if err := k.cdc.Unmarshal(proofInit, &multihopProof); err != nil { + return "", nil, err } - if err := k.connectionKeeper.VerifyMultihopMembership( - ctx, connectionEnd, proofHeight, proofInit, - connectionHops, kvGenerator); err != nil { + // get the last hop connection on the other side of the multihop channel + // the last hop connection is the connection end on the chain before the counterparty multihop chain + lastHopConnectionEnd, err := multihopProof.GetLastHopConnectionEnd(k.cdc, connectionEnd) + if err != nil { return "", nil, err } - } else { - // determine counterparty hops - counterpartyHops := []string{connectionEnd.GetCounterparty().GetConnectionID()} - // check version support - if err := checkVersion(&connectionEnd, order); err != nil { + if err := checkVersion(lastHopConnectionEnd, order); err != nil { return "", nil, err } - // expectedCounterparty is the counterparty of the counterparty's channel end (i.e self) - expectedCounterparty := types.NewCounterparty(portID, "") - expectedChannel := types.NewChannel( - types.INIT, order, expectedCounterparty, - counterpartyHops, counterpartyVersion, - ) - - if err := k.connectionKeeper.VerifyChannelState( - ctx, connectionEnd, proofHeight, proofInit, - counterparty.PortId, counterparty.ChannelId, expectedChannel, - ); err != nil { + counterpartyHops, err = multihopProof.GetCounterpartyConnectionHops(k.cdc, &connectionEnd) + if err != nil { return "", nil, err } + } else { + // check version support + if err := checkVersion(&connectionEnd, order); err != nil { + return "", nil, err + } + + counterpartyHops = []string{connectionEnd.GetCounterparty().GetConnectionID()} } - var ( - capKey *capabilitytypes.Capability - err error + expectedCounterparty := types.NewCounterparty(portID, "") + expectedChannel := types.NewChannel( + types.INIT, + order, + expectedCounterparty, + counterpartyHops, + counterpartyVersion, ) - capKey, err = k.scopedKeeper.NewCapability(ctx, host.ChannelCapabilityPath(portID, channelID)) + if err := k.connectionKeeper.VerifyChannelState( + ctx, connectionEnd, proofHeight, proofInit, + counterparty.PortId, counterparty.ChannelId, expectedChannel, + ); err != nil { + return "", nil, err + } + + capKey, err := k.scopedKeeper.NewCapability(ctx, host.ChannelCapabilityPath(portID, channelID)) if err != nil { return "", nil, errorsmod.Wrapf(err, "could not create channel capability for port ID %s and channel ID %s", portID, channelID) } @@ -288,46 +271,39 @@ func (k Keeper) ChanOpenAck( "connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(), ) } - // verify multihop proof - if len(channel.ConnectionHops) > 1 { - kvGenerator := func(mProof *types.MsgMultihopProofs, _ *connectiontypes.ConnectionEnd) (string, []byte, error) { - key := host.ChannelPath(channel.Counterparty.PortId, counterpartyChannelID) - - counterpartyHops, err := mProof.GetCounterpartyConnectionHops(k.cdc, &connectionEnd) - if err != nil { - return "", nil, err - } - expectedCounterparty := types.NewCounterparty(portID, channelID) - expectedChannel := types.NewChannel( - types.TRYOPEN, channel.Ordering, expectedCounterparty, - counterpartyHops, counterpartyVersion, - ) - value, err := expectedChannel.Marshal() - if err != nil { - return "", nil, err - } - return key, value, nil + var counterpartyHops []string + if len(channel.ConnectionHops) > 1 { + var ( + err error + multihopProof types.MsgMultihopProofs + ) + if err = k.cdc.Unmarshal(proofTry, &multihopProof); err != nil { + return err } - return k.connectionKeeper.VerifyMultihopMembership( - ctx, connectionEnd, proofHeight, proofTry, - channel.ConnectionHops, kvGenerator) - + counterpartyHops, err = multihopProof.GetCounterpartyConnectionHops(k.cdc, &connectionEnd) + if err != nil { + return err + } } else { - counterpartyHops := []string{connectionEnd.GetCounterparty().GetConnectionID()} - expectedCounterparty := types.NewCounterparty(portID, channelID) - expectedChannel := types.NewChannel( - types.TRYOPEN, channel.Ordering, expectedCounterparty, - counterpartyHops, counterpartyVersion, - ) - - return k.connectionKeeper.VerifyChannelState( - ctx, connectionEnd, proofHeight, proofTry, - channel.Counterparty.PortId, counterpartyChannelID, - expectedChannel, - ) + counterpartyHops = []string{connectionEnd.GetCounterparty().GetConnectionID()} } + + // counterparty of the counterparty channel end (i.e self) + expectedCounterparty := types.NewCounterparty(portID, channelID) + expectedChannel := types.NewChannel( + types.TRYOPEN, + channel.Ordering, + expectedCounterparty, + counterpartyHops, + counterpartyVersion, + ) + + return k.connectionKeeper.VerifyChannelState( + ctx, connectionEnd, proofHeight, proofTry, + channel.Counterparty.PortId, counterpartyChannelID, + expectedChannel) } // WriteOpenAckChannel writes an updated channel state for the successful OpenAck handshake step. @@ -394,43 +370,38 @@ func (k Keeper) ChanOpenConfirm( ) } - // verify multihop proof or standard proof + var counterpartyHops []string if len(channel.ConnectionHops) > 1 { - kvGenerator := func(mProof *types.MsgMultihopProofs, _ *connectiontypes.ConnectionEnd) (string, []byte, error) { - key := host.ChannelPath(channel.Counterparty.PortId, channel.Counterparty.ChannelId) - - counterpartyHops, err := mProof.GetCounterpartyConnectionHops(k.cdc, &connectionEnd) - if err != nil { - return "", nil, err - } - counterparty := types.NewCounterparty(portID, channelID) - expectedChannel := types.NewChannel( - types.OPEN, channel.Ordering, counterparty, - counterpartyHops, channel.Version, - ) - value, err := expectedChannel.Marshal() - if err != nil { - return "", nil, err - } - return key, value, nil + var ( + err error + multihopProof types.MsgMultihopProofs + ) + if err = k.cdc.Unmarshal(proofAck, &multihopProof); err != nil { + return err } - return k.connectionKeeper.VerifyMultihopMembership( - ctx, connectionEnd, proofHeight, proofAck, - channel.ConnectionHops, kvGenerator) - + counterpartyHops, err = multihopProof.GetCounterpartyConnectionHops(k.cdc, &connectionEnd) + if err != nil { + return err + } } else { - counterpartyHops := []string{connectionEnd.GetCounterparty().GetConnectionID()} - counterparty := types.NewCounterparty(portID, channelID) - expectedChannel := types.NewChannel( - types.OPEN, channel.Ordering, counterparty, - counterpartyHops, channel.Version) - - return k.connectionKeeper.VerifyChannelState( - ctx, connectionEnd, proofHeight, proofAck, - channel.Counterparty.PortId, channel.Counterparty.ChannelId, - expectedChannel) + counterpartyHops = []string{connectionEnd.GetCounterparty().GetConnectionID()} } + + // counterparty of the counterparty channel end (i.e self) + expectedCounterparty := types.NewCounterparty(portID, channelID) + expectedChannel := types.NewChannel( + types.OPEN, + channel.Ordering, + expectedCounterparty, + counterpartyHops, + channel.Version, + ) + + return k.connectionKeeper.VerifyChannelState( + ctx, connectionEnd, proofHeight, proofAck, + channel.Counterparty.PortId, channel.Counterparty.ChannelId, + expectedChannel) } // WriteOpenConfirmChannel writes an updated channel state for the successful OpenConfirm handshake step. @@ -548,47 +519,39 @@ func (k Keeper) ChanCloseConfirm( ) } - // verify multihop proof + var counterpartyHops []string if len(channel.ConnectionHops) > 1 { - - kvGenerator := func(mProof *types.MsgMultihopProofs, _ *connectiontypes.ConnectionEnd) (string, []byte, error) { - key := host.ChannelPath(channel.Counterparty.PortId, channel.Counterparty.ChannelId) - - counterpartyHops, err := mProof.GetCounterpartyConnectionHops(k.cdc, &connectionEnd) - if err != nil { - return "", nil, err - } - counterparty := types.NewCounterparty(portID, channelID) - expectedChannel := types.NewChannel( - types.CLOSED, channel.Ordering, counterparty, - counterpartyHops, channel.Version, - ) - value, err := expectedChannel.Marshal() - if err != nil { - return "", nil, err - } - return key, value, nil - } - - if err := k.connectionKeeper.VerifyMultihopMembership( - ctx, connectionEnd, proofHeight, proofInit, - channel.ConnectionHops, kvGenerator); err != nil { + var ( + err error + multihopProof types.MsgMultihopProofs + ) + if err = k.cdc.Unmarshal(proofInit, &multihopProof); err != nil { return err } - } else { - counterpartyHops := []string{connectionEnd.GetCounterparty().GetConnectionID()} - counterparty := types.NewCounterparty(portID, channelID) - expectedChannel := types.NewChannel( - types.CLOSED, channel.Ordering, counterparty, - counterpartyHops, channel.Version) - - if err := k.connectionKeeper.VerifyChannelState( - ctx, connectionEnd, proofHeight, proofInit, - channel.Counterparty.PortId, channel.Counterparty.ChannelId, - expectedChannel); err != nil { + counterpartyHops, err = multihopProof.GetCounterpartyConnectionHops(k.cdc, &connectionEnd) + if err != nil { return err } + } else { + counterpartyHops = []string{connectionEnd.GetCounterparty().GetConnectionID()} + } + + // counterparty of the counterparty channel end (i.e self) + expectedCounterparty := types.NewCounterparty(portID, channelID) + expectedChannel := types.NewChannel( + types.CLOSED, + channel.Ordering, + expectedCounterparty, + counterpartyHops, + channel.Version, + ) + + if err := k.connectionKeeper.VerifyChannelState( + ctx, connectionEnd, proofHeight, proofInit, + channel.Counterparty.PortId, channel.Counterparty.ChannelId, + expectedChannel); err != nil { + return err } k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", channel.State.String(), "new-state", types.CLOSED.String())