Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Security fixes #7

Merged
merged 2 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions modules/core/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func (k Keeper) ChannelOpenInitUnchecked(goCtx context.Context, msg *channeltype
// ChannelOpenTry defines a rpc handler method for MsgChannelOpenTry.
// ChannelOpenTry will perform 04-channel checks, route to the application
// callback, and write an OpenTry channel into state upon successful execution.
func (k Keeper) ChannelOpenTry(goCtx context.Context, msg *channeltypes.MsgChannelOpenTry) (*channeltypes.MsgChannelOpenTryResponse, error) {
func (k Keeper) ChannelOpenTryUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelOpenTry) (*channeltypes.MsgChannelOpenTryResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// auto register virtual port if connection is virtual
Expand Down Expand Up @@ -265,7 +265,7 @@ func (k Keeper) ChannelOpenTry(goCtx context.Context, msg *channeltypes.MsgChann
// ChannelOpenAck defines a rpc handler method for MsgChannelOpenAck.
// ChannelOpenAck will perform 04-channel checks, route to the application
// callback, and write an OpenAck channel into state upon successful execution.
func (k Keeper) ChannelOpenAck(goCtx context.Context, msg *channeltypes.MsgChannelOpenAck) (*channeltypes.MsgChannelOpenAckResponse, error) {
func (k Keeper) ChannelOpenAckUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelOpenAck) (*channeltypes.MsgChannelOpenAckResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Lookup module by channel capability
Expand Down Expand Up @@ -307,7 +307,7 @@ func (k Keeper) ChannelOpenAck(goCtx context.Context, msg *channeltypes.MsgChann
// ChannelOpenConfirm defines a rpc handler method for MsgChannelOpenConfirm.
// ChannelOpenConfirm will perform 04-channel checks, route to the application
// callback, and write an OpenConfirm channel into state upon successful execution.
func (k Keeper) ChannelOpenConfirm(goCtx context.Context, msg *channeltypes.MsgChannelOpenConfirm) (*channeltypes.MsgChannelOpenConfirmResponse, error) {
func (k Keeper) ChannelOpenConfirmUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelOpenConfirm) (*channeltypes.MsgChannelOpenConfirmResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Lookup module by channel capability
Expand Down Expand Up @@ -345,7 +345,7 @@ func (k Keeper) ChannelOpenConfirm(goCtx context.Context, msg *channeltypes.MsgC
}

// ChannelCloseInit defines a rpc handler method for MsgChannelCloseInit.
func (k Keeper) ChannelCloseInit(goCtx context.Context, msg *channeltypes.MsgChannelCloseInit) (*channeltypes.MsgChannelCloseInitResponse, error) {
func (k Keeper) ChannelCloseInitUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelCloseInit) (*channeltypes.MsgChannelCloseInitResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)
// Lookup module by channel capability
module, cap, err := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.PortId, msg.ChannelId)
Expand Down Expand Up @@ -378,7 +378,7 @@ func (k Keeper) ChannelCloseInit(goCtx context.Context, msg *channeltypes.MsgCha
}

// ChannelCloseConfirm defines a rpc handler method for MsgChannelCloseConfirm.
func (k Keeper) ChannelCloseConfirm(goCtx context.Context, msg *channeltypes.MsgChannelCloseConfirm) (*channeltypes.MsgChannelCloseConfirmResponse, error) {
func (k Keeper) ChannelCloseConfirmUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelCloseConfirm) (*channeltypes.MsgChannelCloseConfirmResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Lookup module by channel capability
Expand Down Expand Up @@ -412,7 +412,7 @@ func (k Keeper) ChannelCloseConfirm(goCtx context.Context, msg *channeltypes.Msg
}

// ChannelCloseFrozen defines a rpc handler method for MsgChannelCloseFrozen.
func (k Keeper) ChannelCloseFrozen(goCtx context.Context, msg *channeltypes.MsgChannelCloseFrozen) (*channeltypes.MsgChannelCloseFrozenResponse, error) {
func (k Keeper) ChannelCloseFrozenUnchecked(goCtx context.Context, msg *channeltypes.MsgChannelCloseFrozen) (*channeltypes.MsgChannelCloseFrozenResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Lookup module by channel capability
Expand Down Expand Up @@ -447,7 +447,7 @@ func (k Keeper) ChannelCloseFrozen(goCtx context.Context, msg *channeltypes.MsgC
}

// RecvPacket defines a rpc handler method for MsgRecvPacket.
func (k Keeper) RecvPacket(goCtx context.Context, msg *channeltypes.MsgRecvPacket) (*channeltypes.MsgRecvPacketResponse, error) {
func (k Keeper) RecvPacketUnchecked(goCtx context.Context, msg *channeltypes.MsgRecvPacket) (*channeltypes.MsgRecvPacketResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

relayer, err := sdk.AccAddressFromBech32(msg.Signer)
Expand Down Expand Up @@ -679,7 +679,7 @@ func (k Keeper) TimeoutOnClose(goCtx context.Context, msg *channeltypes.MsgTimeo
}

// Acknowledgement defines a rpc handler method for MsgAcknowledgement.
func (k Keeper) Acknowledgement(goCtx context.Context, msg *channeltypes.MsgAcknowledgement) (*channeltypes.MsgAcknowledgementResponse, error) {
func (k Keeper) AcknowledgementUnchecked(goCtx context.Context, msg *channeltypes.MsgAcknowledgement) (*channeltypes.MsgAcknowledgementResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

relayer, err := sdk.AccAddressFromBech32(msg.Signer)
Expand Down
117 changes: 111 additions & 6 deletions modules/core/keeper/msg_server_override.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package keeper

import (
"context"
"fmt"

sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
Expand All @@ -17,14 +18,118 @@ This files contains tx msg endpoint methods that override the default IBC behavi
// ChannelOpenInit defines a rpc handler method for MsgChannelOpenInit.
// ChannelOpenInit will perform 04-channel checks, route to the application
// callback, and write an OpenInit channel into state upon successful execution.
func (k Keeper) ChannelOpenInit(goCtx context.Context, msg *channeltypes.MsgChannelOpenInit) (*channeltypes.MsgChannelOpenInitResponse, error) {
func (k Keeper) ChannelOpenInit(
goCtx context.Context,
msg *channeltypes.MsgChannelOpenInit,
) (*channeltypes.MsgChannelOpenInitResponse, error) {
err := k.ensureNonVirtualSender(goCtx, msg.Channel, "ChannelOpenInit")
if err != nil {
return nil, err
}
return k.ChannelOpenInitUnchecked(goCtx, msg)
}

func (k Keeper) ChannelOpenTry(goCtx context.Context, msg *channeltypes.MsgChannelOpenTry) (*channeltypes.MsgChannelOpenTryResponse, error) {
err := k.ensureNonVirtualConnectionsForChannel(goCtx, "ChannelOpenTry", msg.Channel)
if err != nil {
return nil, err
}
return k.ChannelOpenTryUnchecked(goCtx, msg)
}

func (k Keeper) ChannelOpenAck(goCtx context.Context, msg *channeltypes.MsgChannelOpenAck) (*channeltypes.MsgChannelOpenAckResponse, error) {
err := k.ensureNonVirtualConnections(goCtx, "ChannelOpenAck", msg.PortId, msg.ChannelId)
if err != nil {
return nil, err
}
return k.ChannelOpenAckUnchecked(goCtx, msg)
}

func (k Keeper) ChannelOpenConfirm(goCtx context.Context, msg *channeltypes.MsgChannelOpenConfirm) (*channeltypes.MsgChannelOpenConfirmResponse, error) {
err := k.ensureNonVirtualConnections(goCtx, "ChannelOpenConfirm", msg.PortId, msg.ChannelId)
if err != nil {
return nil, err
}
return k.ChannelOpenConfirmUnchecked(goCtx, msg)
}

func (k Keeper) ChannelCloseInit(goCtx context.Context, msg *channeltypes.MsgChannelCloseInit) (*channeltypes.MsgChannelCloseInitResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)
channel, found := k.ChannelKeeper.GetChannel(ctx, msg.PortId, msg.ChannelId)
if !found {
return nil, sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", msg.PortId, msg.ChannelId)
}

// Ensure the first connection is not virtual; because ChannelOpenInit for virtual channel must go through
// VIBC.OpenIBCChannel endpoint
if isVirtual, connEnd := k.ChannelKeeper.IsVirtualConnection(ctx, msg.Channel.ConnectionHops[0]); isVirtual {
return nil, sdkerrors.Wrapf(connectiontypes.ErrInvalidConnection, "ChanelOpenInit can only be invoked directly on a non-virtual connection, connection: %v", connEnd)
err := k.ensureNonVirtualSender(goCtx, channel, "ChannelCloseInit")
if err != nil {
return nil, err
}

return k.ChannelOpenInitUnchecked(goCtx, msg)
return k.ChannelCloseInitUnchecked(goCtx, msg)
}

func (k Keeper) ChannelCloseConfirm(goCtx context.Context, msg *channeltypes.MsgChannelCloseConfirm) (*channeltypes.MsgChannelCloseConfirmResponse, error) {
err := k.ensureNonVirtualConnections(goCtx, "ChannelCloseConfirm", msg.PortId, msg.ChannelId)
if err != nil {
return nil, err
}
return k.ChannelCloseConfirmUnchecked(goCtx, msg)
}

func (k Keeper) ChannelCloseFrozen(goCtx context.Context, msg *channeltypes.MsgChannelCloseFrozen) (*channeltypes.MsgChannelCloseFrozenResponse, error) {
err := k.ensureNonVirtualConnections(goCtx, "ChannelCloseFrozen", msg.PortId, msg.ChannelId)
if err != nil {
return nil, err
}
return k.ChannelCloseFrozenUnchecked(goCtx, msg)
}

func (k Keeper) RecvPacket(goCtx context.Context, msg *channeltypes.MsgRecvPacket) (*channeltypes.MsgRecvPacketResponse, error) {
err := k.ensureNonVirtualConnections(goCtx, "RecvPacket", msg.Packet.GetDestPort(), msg.Packet.GetDestChannel())
if err != nil {
return nil, err
}
return k.RecvPacketUnchecked(goCtx, msg)
}

func (k Keeper) Acknowledgement(goCtx context.Context, msg *channeltypes.MsgAcknowledgement) (*channeltypes.MsgAcknowledgementResponse, error) {
err := k.ensureNonVirtualConnections(goCtx, "Acknowledgement", msg.Packet.GetSourcePort(), msg.Packet.GetSourceChannel())
if err != nil {
return nil, err
}
return k.AcknowledgementUnchecked(goCtx, msg)
}

func (k Keeper) ensureNonVirtualSender(goCtx context.Context, channel channeltypes.Channel, methodName string) error {
ctx := sdk.UnwrapSDKContext(goCtx)

if isVirtual, connEnd := k.ChannelKeeper.IsVirtualConnection(ctx, channel.ConnectionHops[0]); isVirtual {
return sdkerrors.Wrapf(
connectiontypes.ErrInvalidConnection,
"%s can only be invoked directly on a non-virtual connection, connection: %v",
methodName, connEnd,
)
}
return nil
}

func (k Keeper) ensureNonVirtualConnections(goCtx context.Context, methodName, portID, channelID string) error {
ctx := sdk.UnwrapSDKContext(goCtx)
channel, found := k.ChannelKeeper.GetChannel(ctx, portID, channelID)
if !found {
return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}
return k.ensureNonVirtualConnectionsForChannel(goCtx, methodName, channel)
}

func (k Keeper) ensureNonVirtualConnectionsForChannel(goCtx context.Context, methodName string, channel channeltypes.Channel) error {
ctx := sdk.UnwrapSDKContext(goCtx)
isVirtual := k.ChannelKeeper.IsVirtualEndToVirtualEnd(ctx, channel.ConnectionHops)
if isVirtual {
return sdkerrors.Wrapf(
connectiontypes.ErrInvalidConnection,
fmt.Sprintf("%s can only be invoked directly on non-virtual connections", methodName),
)
}
return nil
}
Loading