From 04e154b9292375c8ffee7c9412c5c2f5e397c2b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=A1s=20Pernas=20Maradei?= Date: Mon, 11 Mar 2024 15:18:39 +0100 Subject: [PATCH 1/2] wip --- modules/core/keeper/msg_server.go | 2 +- modules/core/keeper/msg_server_override.go | 30 ++++++++++++++++++++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/modules/core/keeper/msg_server.go b/modules/core/keeper/msg_server.go index 7c65222f7df..872e33fbcfc 100644 --- a/modules/core/keeper/msg_server.go +++ b/modules/core/keeper/msg_server.go @@ -211,7 +211,7 @@ func (k Keeper) ChannelOpenInitUnchecked(goCtx context.Context, msg *channeltype // ChannelOpenTry defines a rpc handler method for MsgChannelOpenTry. // ChannelOpenTry will perform 04-channel checks, route to the application // callback, and write an OpenTry channel into state upon successful execution. -func (k Keeper) ChannelOpenTry(goCtx context.Context, msg *channeltypes.MsgChannelOpenTry) (*channeltypes.MsgChannelOpenTryResponse, error) { +func (k Keeper) ChannelOpenTryUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelOpenTry) (*channeltypes.MsgChannelOpenTryResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // auto register virtual port if connection is virtual diff --git a/modules/core/keeper/msg_server_override.go b/modules/core/keeper/msg_server_override.go index 8e34b94f761..a0ee745b5ad 100644 --- a/modules/core/keeper/msg_server_override.go +++ b/modules/core/keeper/msg_server_override.go @@ -17,14 +17,40 @@ This files contains tx msg endpoint methods that override the default IBC behavi // ChannelOpenInit defines a rpc handler method for MsgChannelOpenInit. // ChannelOpenInit will perform 04-channel checks, route to the application // callback, and write an OpenInit channel into state upon successful execution. -func (k Keeper) ChannelOpenInit(goCtx context.Context, msg *channeltypes.MsgChannelOpenInit) (*channeltypes.MsgChannelOpenInitResponse, error) { +func (k Keeper) ChannelOpenInit( + goCtx context.Context, + msg *channeltypes.MsgChannelOpenInit, +) (*channeltypes.MsgChannelOpenInitResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // Ensure the first connection is not virtual; because ChannelOpenInit for virtual channel must go through // VIBC.OpenIBCChannel endpoint if isVirtual, connEnd := k.ChannelKeeper.IsVirtualConnection(ctx, msg.Channel.ConnectionHops[0]); isVirtual { - return nil, sdkerrors.Wrapf(connectiontypes.ErrInvalidConnection, "ChanelOpenInit can only be invoked directly on a non-virtual connection, connection: %v", connEnd) + return nil, sdkerrors.Wrapf( + connectiontypes.ErrInvalidConnection, + "ChanelOpenInit can only be invoked directly on a non-virtual connection, connection: %v", + connEnd, + ) } return k.ChannelOpenInitUnchecked(goCtx, msg) } + +func (k Keeper) ChannelOpenTry( + goCtx context.Context, + msg *channeltypes.MsgChannelOpenTry, +) (*channeltypes.MsgChannelOpenTryResponse, error) { + ctx := sdk.UnwrapSDKContext(goCtx) + + // Ensure the first connection is not virtual; because ChannelOpenInit for virtual channel must go through + // VIBC.OpenIBCChannel endpoint + if isVirtual, connEnd := k.ChannelKeeper.IsVirtualConnection(ctx, msg.Channel.ConnectionHops[0]); isVirtual { + return nil, sdkerrors.Wrapf( + connectiontypes.ErrInvalidConnection, + "ChanelOpenTry can only be invoked directly on a non-virtual connection, connection: %v", + connEnd, + ) + } + + return k.ChannelOpenTryUnchecked(goCtx, msg) +} From 7fbc4c722cdb42489c9a3c47a6d102f799da4274 Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 2 Apr 2024 21:13:23 -0400 Subject: [PATCH 2/2] Add additional verification to the rest of the channel handshake methods --- modules/core/keeper/msg_server.go | 14 +-- modules/core/keeper/msg_server_override.go | 119 +++++++++++++++++---- 2 files changed, 106 insertions(+), 27 deletions(-) diff --git a/modules/core/keeper/msg_server.go b/modules/core/keeper/msg_server.go index 872e33fbcfc..b6eb61235ce 100644 --- a/modules/core/keeper/msg_server.go +++ b/modules/core/keeper/msg_server.go @@ -265,7 +265,7 @@ func (k Keeper) ChannelOpenTryUnchecked(goCtx context.Context, msg *channeltypes // ChannelOpenAck defines a rpc handler method for MsgChannelOpenAck. // ChannelOpenAck will perform 04-channel checks, route to the application // callback, and write an OpenAck channel into state upon successful execution. -func (k Keeper) ChannelOpenAck(goCtx context.Context, msg *channeltypes.MsgChannelOpenAck) (*channeltypes.MsgChannelOpenAckResponse, error) { +func (k Keeper) ChannelOpenAckUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelOpenAck) (*channeltypes.MsgChannelOpenAckResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // Lookup module by channel capability @@ -307,7 +307,7 @@ func (k Keeper) ChannelOpenAck(goCtx context.Context, msg *channeltypes.MsgChann // ChannelOpenConfirm defines a rpc handler method for MsgChannelOpenConfirm. // ChannelOpenConfirm will perform 04-channel checks, route to the application // callback, and write an OpenConfirm channel into state upon successful execution. -func (k Keeper) ChannelOpenConfirm(goCtx context.Context, msg *channeltypes.MsgChannelOpenConfirm) (*channeltypes.MsgChannelOpenConfirmResponse, error) { +func (k Keeper) ChannelOpenConfirmUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelOpenConfirm) (*channeltypes.MsgChannelOpenConfirmResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // Lookup module by channel capability @@ -345,7 +345,7 @@ func (k Keeper) ChannelOpenConfirm(goCtx context.Context, msg *channeltypes.MsgC } // ChannelCloseInit defines a rpc handler method for MsgChannelCloseInit. -func (k Keeper) ChannelCloseInit(goCtx context.Context, msg *channeltypes.MsgChannelCloseInit) (*channeltypes.MsgChannelCloseInitResponse, error) { +func (k Keeper) ChannelCloseInitUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelCloseInit) (*channeltypes.MsgChannelCloseInitResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // Lookup module by channel capability module, cap, err := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.PortId, msg.ChannelId) @@ -378,7 +378,7 @@ func (k Keeper) ChannelCloseInit(goCtx context.Context, msg *channeltypes.MsgCha } // ChannelCloseConfirm defines a rpc handler method for MsgChannelCloseConfirm. -func (k Keeper) ChannelCloseConfirm(goCtx context.Context, msg *channeltypes.MsgChannelCloseConfirm) (*channeltypes.MsgChannelCloseConfirmResponse, error) { +func (k Keeper) ChannelCloseConfirmUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelCloseConfirm) (*channeltypes.MsgChannelCloseConfirmResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // Lookup module by channel capability @@ -412,7 +412,7 @@ func (k Keeper) ChannelCloseConfirm(goCtx context.Context, msg *channeltypes.Msg } // ChannelCloseFrozen defines a rpc handler method for MsgChannelCloseFrozen. -func (k Keeper) ChannelCloseFrozen(goCtx context.Context, msg *channeltypes.MsgChannelCloseFrozen) (*channeltypes.MsgChannelCloseFrozenResponse, error) { +func (k Keeper) ChannelCloseFrozenUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelCloseFrozen) (*channeltypes.MsgChannelCloseFrozenResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) // Lookup module by channel capability @@ -447,7 +447,7 @@ func (k Keeper) ChannelCloseFrozen(goCtx context.Context, msg *channeltypes.MsgC } // RecvPacket defines a rpc handler method for MsgRecvPacket. -func (k Keeper) RecvPacket(goCtx context.Context, msg *channeltypes.MsgRecvPacket) (*channeltypes.MsgRecvPacketResponse, error) { +func (k Keeper) RecvPacketUnchecked(goCtx context.Context, msg *channeltypes.MsgRecvPacket) (*channeltypes.MsgRecvPacketResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) relayer, err := sdk.AccAddressFromBech32(msg.Signer) @@ -679,7 +679,7 @@ func (k Keeper) TimeoutOnClose(goCtx context.Context, msg *channeltypes.MsgTimeo } // Acknowledgement defines a rpc handler method for MsgAcknowledgement. -func (k Keeper) Acknowledgement(goCtx context.Context, msg *channeltypes.MsgAcknowledgement) (*channeltypes.MsgAcknowledgementResponse, error) { +func (k Keeper) AcknowledgementUnchecked(goCtx context.Context, msg *channeltypes.MsgAcknowledgement) (*channeltypes.MsgAcknowledgementResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) relayer, err := sdk.AccAddressFromBech32(msg.Signer) diff --git a/modules/core/keeper/msg_server_override.go b/modules/core/keeper/msg_server_override.go index a0ee745b5ad..65d479c309a 100644 --- a/modules/core/keeper/msg_server_override.go +++ b/modules/core/keeper/msg_server_override.go @@ -2,6 +2,7 @@ package keeper import ( "context" + "fmt" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -21,36 +22,114 @@ func (k Keeper) ChannelOpenInit( goCtx context.Context, msg *channeltypes.MsgChannelOpenInit, ) (*channeltypes.MsgChannelOpenInitResponse, error) { + err := k.ensureNonVirtualSender(goCtx, msg.Channel, "ChannelOpenInit") + if err != nil { + return nil, err + } + return k.ChannelOpenInitUnchecked(goCtx, msg) +} + +func (k Keeper) ChannelOpenTry(goCtx context.Context, msg *channeltypes.MsgChannelOpenTry) (*channeltypes.MsgChannelOpenTryResponse, error) { + err := k.ensureNonVirtualConnectionsForChannel(goCtx, "ChannelOpenTry", msg.Channel) + if err != nil { + return nil, err + } + return k.ChannelOpenTryUnchecked(goCtx, msg) +} + +func (k Keeper) ChannelOpenAck(goCtx context.Context, msg *channeltypes.MsgChannelOpenAck) (*channeltypes.MsgChannelOpenAckResponse, error) { + err := k.ensureNonVirtualConnections(goCtx, "ChannelOpenAck", msg.PortId, msg.ChannelId) + if err != nil { + return nil, err + } + return k.ChannelOpenAckUnchecked(goCtx, msg) +} + +func (k Keeper) ChannelOpenConfirm(goCtx context.Context, msg *channeltypes.MsgChannelOpenConfirm) (*channeltypes.MsgChannelOpenConfirmResponse, error) { + err := k.ensureNonVirtualConnections(goCtx, "ChannelOpenConfirm", msg.PortId, msg.ChannelId) + if err != nil { + return nil, err + } + return k.ChannelOpenConfirmUnchecked(goCtx, msg) +} + +func (k Keeper) ChannelCloseInit(goCtx context.Context, msg *channeltypes.MsgChannelCloseInit) (*channeltypes.MsgChannelCloseInitResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) + channel, found := k.ChannelKeeper.GetChannel(ctx, msg.PortId, msg.ChannelId) + if !found { + return nil, sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", msg.PortId, msg.ChannelId) + } - // Ensure the first connection is not virtual; because ChannelOpenInit for virtual channel must go through - // VIBC.OpenIBCChannel endpoint - if isVirtual, connEnd := k.ChannelKeeper.IsVirtualConnection(ctx, msg.Channel.ConnectionHops[0]); isVirtual { - return nil, sdkerrors.Wrapf( - connectiontypes.ErrInvalidConnection, - "ChanelOpenInit can only be invoked directly on a non-virtual connection, connection: %v", - connEnd, - ) + err := k.ensureNonVirtualSender(goCtx, channel, "ChannelCloseInit") + if err != nil { + return nil, err } - return k.ChannelOpenInitUnchecked(goCtx, msg) + return k.ChannelCloseInitUnchecked(goCtx, msg) } -func (k Keeper) ChannelOpenTry( - goCtx context.Context, - msg *channeltypes.MsgChannelOpenTry, -) (*channeltypes.MsgChannelOpenTryResponse, error) { +func (k Keeper) ChannelCloseConfirm(goCtx context.Context, msg *channeltypes.MsgChannelCloseConfirm) (*channeltypes.MsgChannelCloseConfirmResponse, error) { + err := k.ensureNonVirtualConnections(goCtx, "ChannelCloseConfirm", msg.PortId, msg.ChannelId) + if err != nil { + return nil, err + } + return k.ChannelCloseConfirmUnchecked(goCtx, msg) +} + +func (k Keeper) ChannelCloseFrozen(goCtx context.Context, msg *channeltypes.MsgChannelCloseFrozen) (*channeltypes.MsgChannelCloseFrozenResponse, error) { + err := k.ensureNonVirtualConnections(goCtx, "ChannelCloseFrozen", msg.PortId, msg.ChannelId) + if err != nil { + return nil, err + } + return k.ChannelCloseFrozenUnchecked(goCtx, msg) +} + +func (k Keeper) RecvPacket(goCtx context.Context, msg *channeltypes.MsgRecvPacket) (*channeltypes.MsgRecvPacketResponse, error) { + err := k.ensureNonVirtualConnections(goCtx, "RecvPacket", msg.Packet.GetDestPort(), msg.Packet.GetDestChannel()) + if err != nil { + return nil, err + } + return k.RecvPacketUnchecked(goCtx, msg) +} + +func (k Keeper) Acknowledgement(goCtx context.Context, msg *channeltypes.MsgAcknowledgement) (*channeltypes.MsgAcknowledgementResponse, error) { + err := k.ensureNonVirtualConnections(goCtx, "Acknowledgement", msg.Packet.GetSourcePort(), msg.Packet.GetSourceChannel()) + if err != nil { + return nil, err + } + return k.AcknowledgementUnchecked(goCtx, msg) +} + +func (k Keeper) ensureNonVirtualSender(goCtx context.Context, channel channeltypes.Channel, methodName string) error { ctx := sdk.UnwrapSDKContext(goCtx) - // Ensure the first connection is not virtual; because ChannelOpenInit for virtual channel must go through - // VIBC.OpenIBCChannel endpoint - if isVirtual, connEnd := k.ChannelKeeper.IsVirtualConnection(ctx, msg.Channel.ConnectionHops[0]); isVirtual { - return nil, sdkerrors.Wrapf( + if isVirtual, connEnd := k.ChannelKeeper.IsVirtualConnection(ctx, channel.ConnectionHops[0]); isVirtual { + return sdkerrors.Wrapf( connectiontypes.ErrInvalidConnection, - "ChanelOpenTry can only be invoked directly on a non-virtual connection, connection: %v", - connEnd, + "%s can only be invoked directly on a non-virtual connection, connection: %v", + methodName, connEnd, ) } + return nil +} - return k.ChannelOpenTryUnchecked(goCtx, msg) +func (k Keeper) ensureNonVirtualConnections(goCtx context.Context, methodName, portID, channelID string) error { + ctx := sdk.UnwrapSDKContext(goCtx) + channel, found := k.ChannelKeeper.GetChannel(ctx, portID, channelID) + if !found { + return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) + } + return k.ensureNonVirtualConnectionsForChannel(goCtx, methodName, channel) +} + +func (k Keeper) ensureNonVirtualConnectionsForChannel(goCtx context.Context, methodName string, channel channeltypes.Channel) error { + ctx := sdk.UnwrapSDKContext(goCtx) + isVirtual := k.ChannelKeeper.IsVirtualEndToVirtualEnd(ctx, channel.ConnectionHops) + if isVirtual { + return sdkerrors.Wrapf( + connectiontypes.ErrInvalidConnection, + fmt.Sprintf("%s can only be invoked directly on non-virtual connections", methodName), + ) + } + return nil }