From 642df8f3853684dd3d7d1a881c4f5eb521d76ef6 Mon Sep 17 00:00:00 2001 From: Scott Fairclough <70711990+hexoscott@users.noreply.github.com> Date: Tue, 5 Nov 2024 14:42:06 +0000 Subject: [PATCH] limit modexp call to 8192 bit inputs (#1391) * limit modexp call to 8192 bit inputs * logic change to mod exp revert rules * mod len 0 logic in modExp * remove comment from modexp * more mod exp zk tweaks --- core/vm/contracts_zkevm.go | 65 +++++++++++++++++++++++++------ core/vm/contracts_zkevm_test.go | 69 +++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 11 deletions(-) create mode 100644 core/vm/contracts_zkevm_test.go diff --git a/core/vm/contracts_zkevm.go b/core/vm/contracts_zkevm.go index 10047230e26..037f2bc7c42 100644 --- a/core/vm/contracts_zkevm.go +++ b/core/vm/contracts_zkevm.go @@ -299,6 +299,34 @@ func (c *bigModExp_zkevm) RequiredGas(input []byte) uint64 { } else { input = input[:0] } + + // Retrieve the operands and execute the exponentiation + var ( + base = new(big.Int).SetBytes(getData(input, 0, baseLen.Uint64())) + exp = new(big.Int).SetBytes(getData(input, baseLen.Uint64(), expLen.Uint64())) + mod = new(big.Int).SetBytes(getData(input, baseLen.Uint64()+expLen.Uint64(), modLen.Uint64())) + baseBitLen = base.BitLen() + expBitLen = exp.BitLen() + modBitLen = mod.BitLen() + ) + + // zk special cases + // - if mod = 0 we consume gas as normal + // - if base is 0 and mod < 8192 we consume gas as normal + // - if neither of the above are true we check for reverts and return 0 gas fee + + if modBitLen == 0 { + // consume as normal - will return 0 + } else if baseBitLen == 0 { + if modBitLen > 8192 { + return 0 + } else { + // consume as normal - will return 0 + } + } else if baseBitLen > 8192 || expBitLen > 8192 || modBitLen > 8192 { + return 0 + } + // Retrieve the head 32 bytes of exp for the adjusted exponent length var expHead *big.Int if big.NewInt(int64(len(input))).Cmp(baseLen) <= 0 { @@ -373,21 +401,36 @@ func (c *bigModExp_zkevm) Run(input []byte) ([]byte, error) { } else { input = input[:0] } - // Handle a special case when both the base and mod length is zero - if baseLen == 0 && modLen == 0 { - return []byte{}, nil - } + // Retrieve the operands and execute the exponentiation var ( - base = new(big.Int).SetBytes(getData(input, 0, baseLen)) - exp = new(big.Int).SetBytes(getData(input, baseLen, expLen)) - mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen)) - v []byte + base = new(big.Int).SetBytes(getData(input, 0, baseLen)) + exp = new(big.Int).SetBytes(getData(input, baseLen, expLen)) + mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen)) + v []byte + baseBitLen = base.BitLen() + expBitLen = exp.BitLen() + modBitLen = mod.BitLen() ) + + if modBitLen == 0 { + return []byte{}, nil + } + + if baseBitLen == 0 { + if modBitLen > 8192 { + return nil, ErrExecutionReverted + } else { + return common.LeftPadBytes([]byte{}, int(modLen)), nil + } + } + + // limit to 8192 bits for base, exp, and mod in ZK + if baseBitLen > 8192 || expBitLen > 8192 || modBitLen > 8192 { + return nil, ErrExecutionReverted + } + switch { - case mod.BitLen() == 0: - // Modulo 0 is undefined, return zero - return common.LeftPadBytes([]byte{}, int(modLen)), nil case base.Cmp(libcommon.Big1) == 0: //If base == 1, then we can just return base % mod (if mod >= 1, which it is) v = base.Mod(base, mod).Bytes() diff --git a/core/vm/contracts_zkevm_test.go b/core/vm/contracts_zkevm_test.go new file mode 100644 index 00000000000..e3a8d27d3c1 --- /dev/null +++ b/core/vm/contracts_zkevm_test.go @@ -0,0 +1,69 @@ +package vm + +import ( + "testing" + "math/big" +) + +var ( + big0 = big.NewInt(0) + big10 = big.NewInt(10) + big8194 = big.NewInt(0).Lsh(big.NewInt(1), 8194) +) + +func Test_ModExpZkevm_Gas(t *testing.T) { + modExp := bigModExp_zkevm{enabled: true, eip2565: true} + + cases := map[string]struct { + base *big.Int + exp *big.Int + mod *big.Int + expected uint64 + }{ + "simple test": {big10, big10, big10, 200}, + "0 mod - normal gas": {big10, big10, big0, 200}, + "base 0 - mod < 8192 - normal gas": {big0, big10, big10, 200}, + "base 0 - mod > 8192 - 0 gas": {big0, big10, big8194, 0}, + "base over 8192 - 0 gas": {big8194, big10, big10, 0}, + "exp over 8192 - 0 gas": {big10, big8194, big10, 0}, + "mod over 8192 - 0 gas": {big10, big10, big8194, 0}, + } + + for name, test := range cases { + t.Run(name, func(t *testing.T) { + input := make([]byte, 0) + + base := len(test.base.Bytes()) + exp := len(test.exp.Bytes()) + mod := len(test.mod.Bytes()) + + input = append(input, uint64To32Bytes(base)...) + input = append(input, uint64To32Bytes(exp)...) + input = append(input, uint64To32Bytes(mod)...) + input = append(input, uint64ToDeterminedBytes(test.base, base)...) + input = append(input, uint64ToDeterminedBytes(test.exp, exp)...) + input = append(input, uint64ToDeterminedBytes(test.mod, mod)...) + + gas := modExp.RequiredGas(input) + + if gas != test.expected { + t.Errorf("Expected %d, got %d", test.expected, gas) + } + }) + } +} + +func uint64To32Bytes(input int) []byte { + bigInt := new(big.Int).SetUint64(uint64(input)) + bytes := bigInt.Bytes() + result := make([]byte, 32) + copy(result[32-len(bytes):], bytes) + return result +} + +func uint64ToDeterminedBytes(input *big.Int, length int) []byte { + bytes := input.Bytes() + result := make([]byte, length) + copy(result[length-len(bytes):], bytes) + return result +}