From 9b4b9a36067b24b740a554bb56fb325141b3238c Mon Sep 17 00:00:00 2001 From: Gabriel Paradiso Date: Wed, 29 Jan 2025 18:38:39 +0100 Subject: [PATCH] feat: validate decompresed binary size --- pkg/workflows/wasm/host/module.go | 50 ++++++++++++++++++---------- pkg/workflows/wasm/host/wasm_test.go | 20 +++++++++++ 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 0e2214c30..85e4c4a68 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -71,12 +71,13 @@ func (r *store) delete(id string) { } var ( - defaultTickInterval = 100 * time.Millisecond - defaultTimeout = 10 * time.Second - defaultMinMemoryMBs = uint64(128) - DefaultInitialFuel = uint64(100_000_000) - defaultMaxFetchRequests = 5 - defaultMaxCompressedBinarySize = 10 * 1024 * 1024 // 10 MB + defaultTickInterval = 100 * time.Millisecond + defaultTimeout = 10 * time.Second + defaultMinMemoryMBs = uint64(128) + DefaultInitialFuel = uint64(100_000_000) + defaultMaxFetchRequests = 5 + defaultMaxCompressedBinarySize = 10 * 1024 * 1024 // 10 MB + defaultMaxDecompressedBinarySize = 100 * 1024 * 1024 // 100 MB ) type DeterminismConfig struct { @@ -84,16 +85,17 @@ type DeterminismConfig struct { Seed int64 } type ModuleConfig struct { - TickInterval time.Duration - Timeout *time.Duration - MaxMemoryMBs uint64 - MinMemoryMBs uint64 - InitialFuel uint64 - Logger logger.Logger - IsUncompressed bool - Fetch func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) - MaxFetchRequests int - MaxCompressedBinarySize uint64 + TickInterval time.Duration + Timeout *time.Duration + MaxMemoryMBs uint64 + MinMemoryMBs uint64 + InitialFuel uint64 + Logger logger.Logger + IsUncompressed bool + Fetch func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) + MaxFetchRequests int + MaxCompressedBinarySize uint64 + MaxDecompressedBinarySize uint64 // Labeler is used to emit messages from the module. Labeler custmsg.MessageEmitter @@ -173,6 +175,10 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) modCfg.MaxCompressedBinarySize = uint64(defaultMaxCompressedBinarySize) } + if modCfg.MaxDecompressedBinarySize == 0 { + modCfg.MaxDecompressedBinarySize = uint64(defaultMaxDecompressedBinarySize) + } + // Take the max of the min and the configured max memory mbs. // We do this because Go requires a minimum of 16 megabytes to run, // and local testing has shown that with less than the min, some @@ -196,10 +202,10 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) // validate the binary size before decompressing // this is to prevent decompression bombs if uint64(len(binary)) > modCfg.MaxCompressedBinarySize { - return nil, fmt.Errorf("binary size exceeds the maximum allowed size of %d bytes", modCfg.MaxCompressedBinarySize) + return nil, fmt.Errorf("compressed binary size exceeds the maximum allowed size of %d bytes", modCfg.MaxCompressedBinarySize) } - rdr := brotli.NewReader(bytes.NewBuffer(binary)) + rdr := io.LimitReader(brotli.NewReader(bytes.NewBuffer(binary)), int64(modCfg.MaxDecompressedBinarySize)) decompedBinary, err := io.ReadAll(rdr) if err != nil { return nil, fmt.Errorf("failed to decompress binary: %w", err) @@ -208,6 +214,14 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) binary = decompedBinary } + // Validate the decompressed binary size. + // io.LimitReader prevents decompression bombs by reading up to a set limit, but it will not return an error if the limit is reached. + // The Read() method will return io.EOF, and ReadAll will gracefully handle it and return nil. + // Because of this, we treat the limit as a non-inclusive limit. If the limit is reached, we return an error. + if uint64(len(binary)) == modCfg.MaxDecompressedBinarySize { + return nil, fmt.Errorf("decompressed binary size reached the maximum allowed size of %d bytes", modCfg.MaxDecompressedBinarySize) + } + mod, err := wasmtime.NewModule(engine, binary) if err != nil { return nil, fmt.Errorf("error creating wasmtime module: %w", err) diff --git a/pkg/workflows/wasm/host/wasm_test.go b/pkg/workflows/wasm/host/wasm_test.go index 3bd436ac9..a40cac8e4 100644 --- a/pkg/workflows/wasm/host/wasm_test.go +++ b/pkg/workflows/wasm/host/wasm_test.go @@ -950,6 +950,26 @@ func TestModule_CompressedBinarySize(t *testing.T) { }) } +func TestModule_DecompressedBinarySize(t *testing.T) { + t.Parallel() + + // compressed binary size is 4.121 MB + // decompressed binary size is 23.7 MB + binary := createTestBinary(successBinaryCmd, successBinaryLocation, false, t) + t.Run("decompressed binary size is within the limit", func(t *testing.T) { + customDecompressedBinarySize := uint64(24 * 1024 * 1024) + _, err := NewModule(&ModuleConfig{IsUncompressed: false, MaxDecompressedBinarySize: customDecompressedBinarySize, Logger: logger.Test(t)}, binary) + require.NoError(t, err) + }) + + t.Run("decompressed binary size is bigger than the limit", func(t *testing.T) { + customDecompressedBinarySize := uint64(3 * 1024 * 1024) + _, err := NewModule(&ModuleConfig{IsUncompressed: false, MaxDecompressedBinarySize: customDecompressedBinarySize, Logger: logger.Test(t)}, binary) + decompressedSizeExceeded := fmt.Sprintf("decompressed binary size reached the maximum allowed size of %d bytes", customDecompressedBinarySize) + require.ErrorContains(t, err, decompressedSizeExceeded) + }) +} + func TestModule_Sandbox_SleepIsStubbedOut(t *testing.T) { t.Parallel() ctx := tests.Context(t)