diff --git a/modules/core/keeper/msg_server.go b/modules/core/keeper/msg_server.go index 7c65222f7df..b6eb61235ce 100644 --- a/modules/core/keeper/msg_server.go +++ b/modules/core/keeper/msg_server.go @@ -211,7 +211,7 @@ func (k Keeper) ChannelOpenInitUnchecked(goCtx context.Context, msg *channeltype // ChannelOpenTry defines a rpc handler method for MsgChannelOpenTry. // ChannelOpenTry will perform 04-channel checks, route to the application // callback, and write an OpenTry channel into state upon successful execution. -func (k Keeper) ChannelOpenTry(goCtx context.Context, msg *channeltypes.MsgChannelOpenTry) (*channeltypes.MsgChannelOpenTryResponse, error) { +func (k Keeper) ChannelOpenTryUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelOpenTry) (*channeltypes.MsgChannelOpenTryResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // auto register virtual port if connection is virtual @@ -265,7 +265,7 @@ func (k Keeper) ChannelOpenTry(goCtx context.Context, msg *channeltypes.MsgChann // ChannelOpenAck defines a rpc handler method for MsgChannelOpenAck. // ChannelOpenAck will perform 04-channel checks, route to the application // callback, and write an OpenAck channel into state upon successful execution. -func (k Keeper) ChannelOpenAck(goCtx context.Context, msg *channeltypes.MsgChannelOpenAck) (*channeltypes.MsgChannelOpenAckResponse, error) { +func (k Keeper) ChannelOpenAckUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelOpenAck) (*channeltypes.MsgChannelOpenAckResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // Lookup module by channel capability @@ -307,7 +307,7 @@ func (k Keeper) ChannelOpenAck(goCtx context.Context, msg *channeltypes.MsgChann // ChannelOpenConfirm defines a rpc handler method for MsgChannelOpenConfirm. // ChannelOpenConfirm will perform 04-channel checks, route to the application // callback, and write an OpenConfirm channel into state upon successful execution. -func (k Keeper) ChannelOpenConfirm(goCtx context.Context, msg *channeltypes.MsgChannelOpenConfirm) (*channeltypes.MsgChannelOpenConfirmResponse, error) { +func (k Keeper) ChannelOpenConfirmUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelOpenConfirm) (*channeltypes.MsgChannelOpenConfirmResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // Lookup module by channel capability @@ -345,7 +345,7 @@ func (k Keeper) ChannelOpenConfirm(goCtx context.Context, msg *channeltypes.MsgC } // ChannelCloseInit defines a rpc handler method for MsgChannelCloseInit. -func (k Keeper) ChannelCloseInit(goCtx context.Context, msg *channeltypes.MsgChannelCloseInit) (*channeltypes.MsgChannelCloseInitResponse, error) { +func (k Keeper) ChannelCloseInitUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelCloseInit) (*channeltypes.MsgChannelCloseInitResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // Lookup module by channel capability module, cap, err := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.PortId, msg.ChannelId) @@ -378,7 +378,7 @@ func (k Keeper) ChannelCloseInit(goCtx context.Context, msg *channeltypes.MsgCha } // ChannelCloseConfirm defines a rpc handler method for MsgChannelCloseConfirm. -func (k Keeper) ChannelCloseConfirm(goCtx context.Context, msg *channeltypes.MsgChannelCloseConfirm) (*channeltypes.MsgChannelCloseConfirmResponse, error) { +func (k Keeper) ChannelCloseConfirmUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelCloseConfirm) (*channeltypes.MsgChannelCloseConfirmResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // Lookup module by channel capability @@ -412,7 +412,7 @@ func (k Keeper) ChannelCloseConfirm(goCtx context.Context, msg *channeltypes.Msg } // ChannelCloseFrozen defines a rpc handler method for MsgChannelCloseFrozen. -func (k Keeper) ChannelCloseFrozen(goCtx context.Context, msg *channeltypes.MsgChannelCloseFrozen) (*channeltypes.MsgChannelCloseFrozenResponse, error) { +func (k Keeper) ChannelCloseFrozenUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelCloseFrozen) (*channeltypes.MsgChannelCloseFrozenResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // Lookup module by channel capability @@ -447,7 +447,7 @@ func (k Keeper) ChannelCloseFrozen(goCtx context.Context, msg *channeltypes.MsgC } // RecvPacket defines a rpc handler method for MsgRecvPacket. -func (k Keeper) RecvPacket(goCtx context.Context, msg *channeltypes.MsgRecvPacket) (*channeltypes.MsgRecvPacketResponse, error) { +func (k Keeper) RecvPacketUnchecked(goCtx context.Context, msg *channeltypes.MsgRecvPacket) (*channeltypes.MsgRecvPacketResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) relayer, err := sdk.AccAddressFromBech32(msg.Signer) @@ -679,7 +679,7 @@ func (k Keeper) TimeoutOnClose(goCtx context.Context, msg *channeltypes.MsgTimeo } // Acknowledgement defines a rpc handler method for MsgAcknowledgement. -func (k Keeper) Acknowledgement(goCtx context.Context, msg *channeltypes.MsgAcknowledgement) (*channeltypes.MsgAcknowledgementResponse, error) { +func (k Keeper) AcknowledgementUnchecked(goCtx context.Context, msg *channeltypes.MsgAcknowledgement) (*channeltypes.MsgAcknowledgementResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) relayer, err := sdk.AccAddressFromBech32(msg.Signer) diff --git a/modules/core/keeper/msg_server_override.go b/modules/core/keeper/msg_server_override.go index 8e34b94f761..65d479c309a 100644 --- a/modules/core/keeper/msg_server_override.go +++ b/modules/core/keeper/msg_server_override.go @@ -2,6 +2,7 @@ package keeper import ( "context" + "fmt" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -17,14 +18,118 @@ This files contains tx msg endpoint methods that override the default IBC behavi // ChannelOpenInit defines a rpc handler method for MsgChannelOpenInit. // ChannelOpenInit will perform 04-channel checks, route to the application // callback, and write an OpenInit channel into state upon successful execution. -func (k Keeper) ChannelOpenInit(goCtx context.Context, msg *channeltypes.MsgChannelOpenInit) (*channeltypes.MsgChannelOpenInitResponse, error) { +func (k Keeper) ChannelOpenInit( + goCtx context.Context, + msg *channeltypes.MsgChannelOpenInit, +) (*channeltypes.MsgChannelOpenInitResponse, error) { + err := k.ensureNonVirtualSender(goCtx, msg.Channel, "ChannelOpenInit") + if err != nil { + return nil, err + } + return k.ChannelOpenInitUnchecked(goCtx, msg) +} + +func (k Keeper) ChannelOpenTry(goCtx context.Context, msg *channeltypes.MsgChannelOpenTry) (*channeltypes.MsgChannelOpenTryResponse, error) { + err := k.ensureNonVirtualConnectionsForChannel(goCtx, "ChannelOpenTry", msg.Channel) + if err != nil { + return nil, err + } + return k.ChannelOpenTryUnchecked(goCtx, msg) +} + +func (k Keeper) ChannelOpenAck(goCtx context.Context, msg *channeltypes.MsgChannelOpenAck) (*channeltypes.MsgChannelOpenAckResponse, error) { + err := k.ensureNonVirtualConnections(goCtx, "ChannelOpenAck", msg.PortId, msg.ChannelId) + if err != nil { + return nil, err + } + return k.ChannelOpenAckUnchecked(goCtx, msg) +} + +func (k Keeper) ChannelOpenConfirm(goCtx context.Context, msg *channeltypes.MsgChannelOpenConfirm) (*channeltypes.MsgChannelOpenConfirmResponse, error) { + err := k.ensureNonVirtualConnections(goCtx, "ChannelOpenConfirm", msg.PortId, msg.ChannelId) + if err != nil { + return nil, err + } + return k.ChannelOpenConfirmUnchecked(goCtx, msg) +} + +func (k Keeper) ChannelCloseInit(goCtx context.Context, msg *channeltypes.MsgChannelCloseInit) (*channeltypes.MsgChannelCloseInitResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) + channel, found := k.ChannelKeeper.GetChannel(ctx, msg.PortId, msg.ChannelId) + if !found { + return nil, sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", msg.PortId, msg.ChannelId) + } - // Ensure the first connection is not virtual; because ChannelOpenInit for virtual channel must go through - // VIBC.OpenIBCChannel endpoint - if isVirtual, connEnd := k.ChannelKeeper.IsVirtualConnection(ctx, msg.Channel.ConnectionHops[0]); isVirtual { - return nil, sdkerrors.Wrapf(connectiontypes.ErrInvalidConnection, "ChanelOpenInit can only be invoked directly on a non-virtual connection, connection: %v", connEnd) + err := k.ensureNonVirtualSender(goCtx, channel, "ChannelCloseInit") + if err != nil { + return nil, err } - return k.ChannelOpenInitUnchecked(goCtx, msg) + return k.ChannelCloseInitUnchecked(goCtx, msg) +} + +func (k Keeper) ChannelCloseConfirm(goCtx context.Context, msg *channeltypes.MsgChannelCloseConfirm) (*channeltypes.MsgChannelCloseConfirmResponse, error) { + err := k.ensureNonVirtualConnections(goCtx, "ChannelCloseConfirm", msg.PortId, msg.ChannelId) + if err != nil { + return nil, err + } + return k.ChannelCloseConfirmUnchecked(goCtx, msg) +} + +func (k Keeper) ChannelCloseFrozen(goCtx context.Context, msg *channeltypes.MsgChannelCloseFrozen) (*channeltypes.MsgChannelCloseFrozenResponse, error) { + err := k.ensureNonVirtualConnections(goCtx, "ChannelCloseFrozen", msg.PortId, msg.ChannelId) + if err != nil { + return nil, err + } + return k.ChannelCloseFrozenUnchecked(goCtx, msg) +} + +func (k Keeper) RecvPacket(goCtx context.Context, msg *channeltypes.MsgRecvPacket) (*channeltypes.MsgRecvPacketResponse, error) { + err := k.ensureNonVirtualConnections(goCtx, "RecvPacket", msg.Packet.GetDestPort(), msg.Packet.GetDestChannel()) + if err != nil { + return nil, err + } + return k.RecvPacketUnchecked(goCtx, msg) +} + +func (k Keeper) Acknowledgement(goCtx context.Context, msg *channeltypes.MsgAcknowledgement) (*channeltypes.MsgAcknowledgementResponse, error) { + err := k.ensureNonVirtualConnections(goCtx, "Acknowledgement", msg.Packet.GetSourcePort(), msg.Packet.GetSourceChannel()) + if err != nil { + return nil, err + } + return k.AcknowledgementUnchecked(goCtx, msg) +} + +func (k Keeper) ensureNonVirtualSender(goCtx context.Context, channel channeltypes.Channel, methodName string) error { + ctx := sdk.UnwrapSDKContext(goCtx) + + if isVirtual, connEnd := k.ChannelKeeper.IsVirtualConnection(ctx, channel.ConnectionHops[0]); isVirtual { + return sdkerrors.Wrapf( + connectiontypes.ErrInvalidConnection, + "%s can only be invoked directly on a non-virtual connection, connection: %v", + methodName, connEnd, + ) + } + return nil +} + +func (k Keeper) ensureNonVirtualConnections(goCtx context.Context, methodName, portID, channelID string) error { + ctx := sdk.UnwrapSDKContext(goCtx) + channel, found := k.ChannelKeeper.GetChannel(ctx, portID, channelID) + if !found { + return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) + } + return k.ensureNonVirtualConnectionsForChannel(goCtx, methodName, channel) +} + +func (k Keeper) ensureNonVirtualConnectionsForChannel(goCtx context.Context, methodName string, channel channeltypes.Channel) error { + ctx := sdk.UnwrapSDKContext(goCtx) + isVirtual := k.ChannelKeeper.IsVirtualEndToVirtualEnd(ctx, channel.ConnectionHops) + if isVirtual { + return sdkerrors.Wrapf( + connectiontypes.ErrInvalidConnection, + fmt.Sprintf("%s can only be invoked directly on non-virtual connections", methodName), + ) + } + return nil }