Skip to content

Commit

Permalink
feat: Add option to actually calculate payload hash
Browse files Browse the repository at this point in the history
  • Loading branch information
achetronic committed Oct 23, 2024
1 parent 8cf9d6a commit abf8152
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 45 deletions.
5 changes: 5 additions & 0 deletions api/config_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,17 @@ import (
)

type BifrostConfigT struct {
Common CommonT `yaml:"common"`
Listener ListenerT `yaml:"listener"`
Authentication AuthenticationT `yaml:"authentication"`
Modifiers []ModifierT `yaml:"modifiers"`
Target TargetT `yaml:"target"`
}

type CommonT struct {
EnablePayloadHashCalculation bool `yaml:"enablePayloadHashCalculation"`
}

type ListenerT struct {
Port string `yaml:"port"`
Host string `yaml:"host"`
Expand Down
4 changes: 2 additions & 2 deletions charts/bifrost/Chart.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ type: application
description: >-
A Helm chart for bifrost, a lightweight S3 proxy that re-signs requests between
your customers and buckets, supporting multiple client authentication methods.
version: 0.8.0 # chart version
appVersion: 0.8.0 # application version
version: 0.9.0 # chart version
appVersion: 0.9.0 # application version
kubeVersion: ">=1.22.0-0" # kubernetes version
home: https://github.com/freepik-company/bifrost
sources:
Expand Down
25 changes: 18 additions & 7 deletions docs/samples/bifrost.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,26 @@
# The whole file supports environment variables expansion,
# so you can use them in any part of the file

# (Optional)
common:
# This option enables performing hash calculation of the request's payload. If enabled, the proxy
# reads and stores the entire request body in a temporary file while calculating the hash.
# Doing this increase the security (hashes are actually verified) at the cost of a decrease in performance.
enablePayloadHashCalculation: true

# (Required) Listener which will attend incoming requests
listener:
port: 7777
host: 0.0.0.0

options:
readTimeout: 10s
writeTimeout: 10s
maxConcurrentConnections: 1000 # Zero (0) means no limit
disableKeepAlives: false
readTimeout: 0s # Zero (0) means no limit
writeTimeout: 0s # Zero (0) means no limit
maxConcurrentConnections: 0 # Zero (0) means no limit

# Disabling keep-alive connections decrease memory usage at the cost of CPU consumption increase
# due to the overhead of creating new connections for each request.
disableKeepAlives: true

# (Optional) Authentication configuration
authentication:
Expand Down Expand Up @@ -43,8 +53,6 @@ authentication:
s3:
signatureVerification: false

# TBD: Add support for other authentication types

# (Optional) List of modifiers to apply to the request before signing it
modifiers:
- type: path
Expand All @@ -64,4 +72,7 @@ target:
options:
dialTimeout: 10s
keepAlive: 30s
disableKeepAlives: false

# Disabling keep-alive connections decrease memory usage at the cost of CPU consumption increase
# due to the overhead of creating new connections for each request.
disableKeepAlives: true
87 changes: 61 additions & 26 deletions internal/httpserver/httpserver.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package httpserver

import (
"crypto/md5"
"crypto/sha256"
"crypto/tls"
"encoding/xml"
"fmt"
"hash"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -144,48 +142,53 @@ func getBucketCredential(request *http.Request) (*api.BucketCredentialT, error)
return nil, fmt.Errorf("unable to find bucket credential")
}

// getPayloadHash copy the request.Body content into a temporary file to calculate the hash.
// getPayloadHashFromHeader trust the X-Amz-Content-Sha256 header to extract the hash
// already calculated by the user's CLI
func getPayloadHashFromHeader(req *http.Request) (payloadHash string) {

payloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
payloadHashHeader := req.Header.Get("x-amz-content-sha256")
if payloadHashHeader != "" {
payloadHash = payloadHashHeader
}

return payloadHash
}

// getPayloadHashFromBody copy the request.Body content into a temporary file to calculate the hash.
// Once calculated, it returns the hash and the pointer to the file to use its content later
func getPayloadHash(hashType string, req *http.Request) (fileHash string, file *os.File, err error) {
func getPayloadHashFromBody(req *http.Request) (payloadHash string, payloadContent *os.File, err error) {

// Create temporary file to store the content
filePtr, err := os.CreateTemp(os.TempDir(), "req-payload-*")
payloadContent, err = os.CreateTemp(os.TempDir(), "bifrost-req-payload-*.tmp")
if err != nil {
return fileHash, file, fmt.Errorf("failed creating temp file: %s", err.Error())
return payloadHash, payloadContent, fmt.Errorf("failed creating temp file: %s", err.Error())
}

var hasher hash.Hash
switch hashType {
case "sha256":
hasher = sha256.New()
case "md5":
hasher = md5.New()
default:
return fileHash, file, fmt.Errorf("hash type not supported")
}
hasher := sha256.New()

// Write the request into (temporary file + hasher) at once
multiWriterEntity := io.MultiWriter(filePtr, hasher)
multiWriterEntity := io.MultiWriter(payloadContent, hasher)

// Create a new Reader entity that will spy the content of r.Body while the last it's being copied
spyReaderEntity := io.TeeReader(req.Body, multiWriterEntity)

//
_, err = io.Copy(io.Discard, spyReaderEntity)
if err != nil {
return fileHash, file, fmt.Errorf("failed copying data: %s", err.Error())
return payloadHash, payloadContent, fmt.Errorf("failed copying data: %s", err.Error())
}

//
fileHash = fmt.Sprintf("%x", hasher.Sum(nil))
filePtr.Seek(0, io.SeekStart)
payloadHash = fmt.Sprintf("%x", hasher.Sum(nil))
payloadContent.Seek(0, io.SeekStart)

return fileHash, filePtr, nil
return payloadHash, payloadContent, nil
}

// isValidSignature verifies the signature of the request using the provided bucket credential
// It produces another signature over the same request and compares them
func isValidSignature(bucketCredential *api.BucketCredentialT, request *http.Request) (bool, error) {
func isValidSignature(bucketCredential *api.BucketCredentialT, request *http.Request, payloadHash string) (bool, error) {

globals.Application.Logger.Debugf("[signature validation] client original request: %v", request)

Expand Down Expand Up @@ -219,7 +222,7 @@ func isValidSignature(bucketCredential *api.BucketCredentialT, request *http.Req
globals.Application.Logger.Debugf("[signature validation] simulated request before signing: %v", simulatedReq)

// Sign the faked request with provided credentials
err = signature.SignS3Version4(bucketCredential.AwsConfig, simulatedReq)
err = signature.SignS3Version4(bucketCredential.AwsConfig, simulatedReq, payloadHash)
if err != nil {
return false, fmt.Errorf("failed to sign simulated request: %s", err.Error())
}
Expand Down Expand Up @@ -297,7 +300,6 @@ func (s *HttpServer) handleRequest(response http.ResponseWriter, request *http.R
// Optional: Flush the encoder to ensure all data is written
encoder.Flush()
}

}()

defer request.Body.Close()
Expand All @@ -309,11 +311,37 @@ func (s *HttpServer) handleRequest(response http.ResponseWriter, request *http.R
return
}

// Calculate hash of the request payload to be used in verification and signature
var payloadHash string
var payloadContent *os.File

payloadHash = getPayloadHashFromHeader(request)

if (request.Method == http.MethodPost || request.Method == http.MethodPut) &&
globals.Application.Config.Common.EnablePayloadHashCalculation {

payloadHash, payloadContent, localErr = getPayloadHashFromBody(request)
if localErr != nil {
err = fmt.Errorf("failed calculating hash: %s", localErr.Error())
return
}

defer func() {
payloadContent.Close()

localErr = os.Remove(payloadContent.Name())
if localErr != nil {
err = fmt.Errorf("failed cleaning hash assets: %s", localErr.Error())
return
}
}()
}

// Verify the signature of the request when using S3 credentials for client authentication
if globals.Application.Config.Authentication.ClientCredentials.Type == "s3" &&
globals.Application.Config.Authentication.ClientCredentials.S3.SignatureVerification {

isValid, localErr := isValidSignature(bucketCredential, request)
isValid, localErr := isValidSignature(bucketCredential, request, payloadHash)
if localErr != nil {
err = fmt.Errorf("failed to validate request signature: %s", localErr.Error())
return
Expand Down Expand Up @@ -349,7 +377,14 @@ func (s *HttpServer) handleRequest(response http.ResponseWriter, request *http.R
targetRequestUrl := fmt.Sprintf("%s://%s%s",
globals.Application.Config.Target.Scheme, targetHostString, request.URL.Path+"?"+request.URL.RawQuery)

targetReq, localErr := http.NewRequest(request.Method, targetRequestUrl, request.Body)
// Read from the request or from the file depending on the params
payloadReader := request.Body
if (request.Method == http.MethodPost || request.Method == http.MethodPut) &&
globals.Application.Config.Common.EnablePayloadHashCalculation {
payloadReader = payloadContent
}

targetReq, localErr := http.NewRequest(request.Method, targetRequestUrl, payloadReader)
if localErr != nil {
err = fmt.Errorf("failed to create request: %s", localErr.Error())
return
Expand All @@ -372,7 +407,7 @@ func (s *HttpServer) handleRequest(response http.ResponseWriter, request *http.R
}

// Sign the request
localErr = signature.SignS3Version4(bucketCredential.AwsConfig, targetReq)
localErr = signature.SignS3Version4(bucketCredential.AwsConfig, targetReq, payloadHash)
if localErr != nil {
err = fmt.Errorf("failed to sign request: %s", localErr.Error())
return
Expand Down
11 changes: 1 addition & 10 deletions internal/signature/signature.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,13 @@ import (

// Sign a S3 request using AWS Signature Version 4.
// Ref:
func SignS3Version4(cfg *aws.Config, req *http.Request) (err error) {
func SignS3Version4(cfg *aws.Config, req *http.Request, payloadHash string) (err error) {

awsCredentials, err := cfg.Credentials.Retrieve(context.TODO())
if err != nil {
log.Fatalf("unable to get credentials, %v", err)
}

// Trust the X-Amz-Content-Sha256 header if it is present in the request.
// Calculating the payload hash is expensive, so we trust the client to provide it.
// TODO: Calculate the payload hash in future versions.
payloadHash := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
payloadHashHeader := req.Header.Get("x-amz-content-sha256")
if payloadHashHeader != "" {
payloadHash = payloadHashHeader
}

// Trust the X-Amz-Date header if it is present in the request
signingTime := time.Now()
dateHeader := req.Header.Get("x-amz-date")
Expand Down

0 comments on commit abf8152

Please sign in to comment.