From 982794651c28cf44b451bc4a68ec54494fa50bc6 Mon Sep 17 00:00:00 2001 From: Vitalii Levitskii Date: Thu, 11 Apr 2024 12:28:07 +0200 Subject: [PATCH] Emit metrics from potentially changed header handlers --- CHANGELOG.md | 12 + docs/headers-handling.md | 86 ++++++ internal/observability/public.go | 76 +++++ internal/observability/public_test.go | 99 ++++++ transport/grpc/handler.go | 10 +- transport/grpc/headers.go | 88 ++++-- transport/grpc/headers_test.go | 335 ++++++++++++++++----- transport/grpc/outbound.go | 16 +- transport/grpc/response_writer.go | 16 +- transport/grpc/response_writer_test.go | 83 +++++ transport/tchannel/channel_outbound.go | 5 +- transport/tchannel/channel_transport.go | 5 +- transport/tchannel/handler.go | 149 ++------- transport/tchannel/handler_test.go | 32 +- transport/tchannel/header.go | 50 ++- transport/tchannel/header_test.go | 124 +++++++- transport/tchannel/outbound.go | 24 +- transport/tchannel/response_writer.go | 162 ++++++++++ transport/tchannel/response_writer_test.go | 80 +++++ transport/tchannel/tchannel_utils_test.go | 2 +- transport/tchannel/transport.go | 3 +- 21 files changed, 1186 insertions(+), 271 deletions(-) create mode 100644 docs/headers-handling.md create mode 100644 internal/observability/public.go create mode 100644 internal/observability/public_test.go create mode 100644 transport/grpc/response_writer_test.go create mode 100644 transport/tchannel/response_writer.go create mode 100644 transport/tchannel/response_writer_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index da56c06ce..1ae09b12d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,18 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] - Upgraded go version to 1.21, set toolchain version. - Reverted rpc-caller-procedure value setting. +- Preparation for the new header handling behavior. + +Starting from one of the next releases, following behaviour changes will be applied: +- Any inbound header with the prefix `rpc-` will be treated as a reserved header and will be ignored +(i.e. not forwarded to an application code) in both tchannel and grpc. +Currently, unknown headers with the prefix `rpc-` are forwarded to the application code in both transports. +- Any attempt to set request/response header with the prefix `rpc-` will result in an error in tchannel. +Currently, the same behavior is applied only to the grpc transport, while tchannel allows setting such headers. + +As an intermediate step, the `reserved_headers_stripped` and `reserved_headers_error` metrics +with `"component": "yarpc-header-migration"` constant tag and with `source` and `dest` variable tags +will be emitted to help to identify the edges that are affected by the changes. ## [1.72.1] - 2024-03-14 - tchannel: Renamed caller-procedure header from `$rpc$-caller-procedure` to `rpc-caller-procedure`. diff --git a/docs/headers-handling.md b/docs/headers-handling.md new file mode 100644 index 000000000..e1db6eef5 --- /dev/null +++ b/docs/headers-handling.md @@ -0,0 +1,86 @@ +# Headers handling + +Yarpc has unified API for getting and setting headers. Although implementations may wary +significantly from one transport to another. + +This document describes details of headers handling in Yarpc. + +# Existing behaviour + +## HTTP + +### Outbound - Request (writing via req.Headers.With) + +All application headers are prepended with an 'Rpc-Header' prefix. + +### Inbound - Request (Parsing) + +Predefined list of headers is read and stripped from the inbound request. + +Headers with prefix 'Rpc-Header-' will be forwarded to an application code (without prefix). + +Only headers explicitly specified in the config will be passed to the application code as is. If header name doesn't have 'x-' prefix, 'header %s does not begin with 'x-'' message is returned. + +### Inbound - Response (Writing) + +All application headers are prepended with an 'Rpc-Header' prefix. + +### Outbound - Response (Parsing) + +Headers with prefix 'Rpc-Header-' will be forwarded to an application code (without prefix). + +## TChannel + +### Outbound - Request (writing via req.Headers.With) + +Headers with any name may be added. + +### Inbound - Request (Parsing) + +Predefined list of headers (one header, actually) is read and stripped from the inbound request. + +All other headers are forwarded as is to an application code. + +### Inbound - Response (Writing) + +Attempting to add a header with a name listed as reserved leads to an error "cannot use reserved header key". + +### Outbound - Response (Parsing) + +Headers with the names listed as reserved are deleted. All other headers are forwarded to an application code as is. + +## GRPC + +### Outbound - Request (writing via req.Headers.With) + +Attempting to add headers with some of reserved names or with already set values lead to 'duplicate key' error. + +Attempting to add headers with 'rpc-' prefix leads to 'cannot use reserved header in application headers' error. + +### Inbound - Request (Parsing) + +Predefined list of headers is read and stripped from the inbound request. + +All other headers are forwarded as is to an application code. + +### Inbound - Response (Writing) + +Attempting to add headers with some of reserved names or already set values lead to 'duplicate key' error. + +Attempting to add headers with 'rpc-' prefix leads to 'cannot use reserved header in application headers' error. + +### Outbound - Response (Parsing) + +Headers with 'rpc-' prefix will be omitted from forwarding to an application code. + +# New behaviour + +## HTTP, TChannel, GRPC + +### Outbound - Request (writing via req.Headers.With) and Inbound - Response (Writing) + +Attempting to add a header with a 'prc-' or '$rpc$-' prefixes leads to an error "cannot use reserved header key". + +### Inbound - Request (Parsing) and Outbound - Response (Parsing) + +Unparsed headers with 'rpc-' or '$rpc$-' prefixes ignored, i.e. not forwarded to an application code. diff --git a/internal/observability/public.go b/internal/observability/public.go new file mode 100644 index 000000000..3806473ca --- /dev/null +++ b/internal/observability/public.go @@ -0,0 +1,76 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package observability + +import ( + "sync" + + "go.uber.org/net/metrics" +) + +var ( + reservedHeaderStripped *metrics.CounterVector + reservedHeaderError *metrics.CounterVector + + registerHeaderMetricsOnce sync.Once +) + +// IncReservedHeaderStripped increments the counter for reserved headers being stripped. +func IncReservedHeaderStripped(m *metrics.Scope, source, dest string) { + registerHeaderMetrics(m) + incHeaderMetric(reservedHeaderStripped, source, dest) +} + +// IncReservedHeaderError increments the counter for reserved headers led to error. +func IncReservedHeaderError(m *metrics.Scope, source, dest string) { + registerHeaderMetrics(m) + incHeaderMetric(reservedHeaderError, source, dest) +} + +func registerHeaderMetrics(m *metrics.Scope) { + if m == nil { + return + } + + registerHeaderMetricsOnce.Do(func() { + reservedHeaderStripped, _ = m.CounterVector(metrics.Spec{ + Name: "reserved_headers_stripped", + Help: "Total number of reserved headers being stripped.", + ConstTags: map[string]string{"component": "yarpc-header-migration"}, + VarTags: []string{"source", "dest"}, + }) + + reservedHeaderError, _ = m.CounterVector(metrics.Spec{ + Name: "reserved_headers_error", + Help: "Total number of reserved headers led to error.", + ConstTags: map[string]string{"component": "yarpc-header-migration"}, + VarTags: []string{"source", "dest"}, + }) + }) +} + +func incHeaderMetric(vector *metrics.CounterVector, source, dest string) { + if vector != nil { + if counter, err := vector.Get("source", source, "dest", dest); counter != nil && err == nil { + counter.Inc() + } + } +} diff --git a/internal/observability/public_test.go b/internal/observability/public_test.go new file mode 100644 index 000000000..67b9ad259 --- /dev/null +++ b/internal/observability/public_test.go @@ -0,0 +1,99 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package observability + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/net/metrics" +) + +func TestReservedHeaderMetrics(t *testing.T) { + m := metrics.New() + + t.Run("nil-scope", func(t *testing.T) { + IncReservedHeaderStripped(nil, "", "") + IncReservedHeaderError(nil, "", "") + }) + + t.Run("nil-counters", func(t *testing.T) { + // Counters registration called only once + registerHeaderMetrics(m.Scope()) + + var ( + registeredStripped = reservedHeaderStripped + registeredError = reservedHeaderError + ) + t.Cleanup(func() { + reservedHeaderStripped = registeredStripped + reservedHeaderError = registeredError + }) + reservedHeaderStripped = nil + reservedHeaderError = nil + + IncReservedHeaderStripped(m.Scope(), "", "") + IncReservedHeaderError(m.Scope(), "", "") + }) + + t.Run("inc-header-metric", func(t *testing.T) { + IncReservedHeaderStripped(m.Scope(), "source", "dest") + IncReservedHeaderError(m.Scope(), "source", "dest") + + IncReservedHeaderStripped(m.Scope(), "source", "dest") + IncReservedHeaderStripped(m.Scope(), "source", "dest-2") + IncReservedHeaderStripped(m.Scope(), "source-2", "dest-2") + + s := m.Snapshot() + + var ( + strippedFound, errorFound bool + ) + for _, c := range s.Counters { + if c.Name == "reserved_headers_stripped" { + strippedFound = true + + if c.Tags["source"] == "source" && c.Tags["dest"] == "dest" { + assert.Equal(t, int64(2), c.Value) + } else if c.Tags["source"] == "source" && c.Tags["dest"] == "dest-2" { + assert.Equal(t, int64(1), c.Value) + } else if c.Tags["source"] == "source-2" && c.Tags["dest"] == "dest-2" { + assert.Equal(t, int64(1), c.Value) + } else { + t.Errorf("unexpected counter: %v", c) + } + } else if c.Name == "reserved_headers_error" { + errorFound = true + + if c.Tags["source"] == "source" && c.Tags["dest"] == "dest" { + assert.Equal(t, int64(1), c.Value) + } else { + t.Errorf("unexpected counter: %v", c) + } + } else { + t.Errorf("unexpected counter: %v", c) + } + } + + assert.True(t, strippedFound) + assert.True(t, errorFound) + }) +} diff --git a/transport/grpc/handler.go b/transport/grpc/handler.go index 86ea56660..a1d3b7428 100644 --- a/transport/grpc/handler.go +++ b/transport/grpc/handler.go @@ -29,6 +29,7 @@ import ( "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/internal/bufferpool" "go.uber.org/yarpc/internal/grpcerrorcodes" + "go.uber.org/yarpc/internal/observability" "go.uber.org/yarpc/yarpcerrors" "go.uber.org/zap" "golang.org/x/net/context" @@ -88,7 +89,10 @@ func (h *handler) getBasicTransportRequest(ctx context.Context, streamMethod str if md == nil || !ok { return nil, yarpcerrors.Newf(yarpcerrors.CodeInternal, "cannot get metadata from ctx: %v", ctx) } - transportRequest, err := metadataToTransportRequest(md) + transportRequest, reportHeader, err := metadataToInboundRequest(md) + if reportHeader { + observability.IncReservedHeaderStripped(h.i.t.options.meter, transportRequest.Caller, transportRequest.Service) + } if err != nil { return nil, err } @@ -190,6 +194,10 @@ func (h *handler) handleUnary( err := h.handleUnaryBeforeErrorConversion(ctx, transportRequest, responseWriter, start, handler) err = handlerErrorToGRPCError(err, responseWriter) + if responseWriter.reportHeader { + observability.IncReservedHeaderError(h.i.t.options.meter, transportRequest.Caller, transportRequest.Service) + } + // Send the response attributes back and end the stream. // // Warning: SendMsg() holds onto these bytes after returning. Therefore, we diff --git a/transport/grpc/headers.go b/transport/grpc/headers.go index 95e8d2713..4c0812179 100644 --- a/transport/grpc/headers.go +++ b/transport/grpc/headers.go @@ -88,6 +88,12 @@ const ( contentTypeHeader = "content-type" ) +var ( + // enforceHeaderRules is a feature flag for a more strict error handling rules. + // See https://github.com/yarpc/yarpc-go/pull/2259 for more details. + enforceHeaderRules = false +) + // TODO: there are way too many repeat calls to strings.ToLower // Note that these calls are done indirectly, primarily through // transport.CanonicalizeHeaderKey @@ -96,11 +102,15 @@ func isReserved(header string) bool { return strings.HasPrefix(strings.ToLower(header), "rpc-") } -// transportRequestToMetadata will populate all reserved and application headers +func isReservedWithDollarSign(header string) bool { + return strings.HasPrefix(strings.ToLower(header), "$rpc$-") +} + +// outboundRequestToMetadata populates all reserved and application headers // from the Request into a new MD. -func transportRequestToMetadata(request *transport.Request) (metadata.MD, error) { - md := metadata.New(nil) - if err := multierr.Combine( +func outboundRequestToMetadata(request *transport.Request) (md metadata.MD, reportHeader bool, err error) { + md = metadata.New(nil) + err = multierr.Combine( addToMetadata(md, CallerHeader, request.Caller), addToMetadata(md, ServiceHeader, request.Service), addToMetadata(md, ShardKeyHeader, request.ShardKey), @@ -108,18 +118,23 @@ func transportRequestToMetadata(request *transport.Request) (metadata.MD, error) addToMetadata(md, RoutingDelegateHeader, request.RoutingDelegate), addToMetadata(md, EncodingHeader, string(request.Encoding)), addToMetadata(md, CallerProcedureHeader, request.CallerProcedure), - ); err != nil { - return md, err + ) + if err != nil { + return } - return md, addApplicationHeaders(md, request.Headers) + + reportHeader, err = addApplicationHeaders(md, request.Headers) + return } -// metadataToTransportRequest will populate the Request with all reserved and application +// metadataToInboundRequest populates the Request with all reserved and application // headers into a new Request, only not setting the Body field. -func metadataToTransportRequest(md metadata.MD) (*transport.Request, error) { +func metadataToInboundRequest(md metadata.MD) (*transport.Request, bool, error) { request := &transport.Request{ Headers: transport.NewHeadersWithCapacity(md.Len()), } + reportStrippedHeader := false + for header, values := range md { var value string switch len(values) { @@ -128,7 +143,7 @@ func metadataToTransportRequest(md metadata.MD) (*transport.Request, error) { case 1: value = values[0] default: - return nil, yarpcerrors.InvalidArgumentErrorf("header has more than one value: %s:%v", header, values) + return nil, reportStrippedHeader, yarpcerrors.InvalidArgumentErrorf("header has more than one value: %s:%v", header, values) } header = transport.CanonicalizeHeaderKey(header) switch header { @@ -153,10 +168,17 @@ func metadataToTransportRequest(md metadata.MD) (*transport.Request, error) { request.Encoding = transport.Encoding(getContentSubtype(value)) } default: + if isReserved(header) || isReservedWithDollarSign(header) { + reportStrippedHeader = true + if enforceHeaderRules { + continue + } + } request.Headers = request.Headers.With(header, value) } } - return request, nil + + return request, reportStrippedHeader, nil } func metadataToApplicationErrorMeta(responseMD metadata.MD) *transport.ApplicationErrorMeta { @@ -182,30 +204,51 @@ func metadataToApplicationErrorMeta(responseMD metadata.MD) *transport.Applicati } // addApplicationHeaders adds the headers to md. -func addApplicationHeaders(md metadata.MD, headers transport.Headers) error { +func addApplicationHeaders(md metadata.MD, headers transport.Headers) (reportHeader bool, err error) { for header, value := range headers.Items() { header = transport.CanonicalizeHeaderKey(header) + if isReserved(header) { - return yarpcerrors.InvalidArgumentErrorf("cannot use reserved header in application headers: %s", header) + err = yarpcerrors.InvalidArgumentErrorf("cannot use reserved header in application headers: %s", header) + return } - if err := addToMetadata(md, header, value); err != nil { - return err + + if isReservedWithDollarSign(header) { + reportHeader = true + if enforceHeaderRules { + err = yarpcerrors.InternalErrorf("cannot use reserved header in application headers: %s", header) + return + } + } + + if err = addToMetadata(md, header, value); err != nil { + return } } - return nil + + return } -// getApplicationHeaders returns the headers from md without any reserved headers. -func getApplicationHeaders(md metadata.MD) (transport.Headers, error) { +// getOutboundResponseApplicationHeaders returns the headers from md without any reserved headers. +func getOutboundResponseApplicationHeaders(md metadata.MD) (transport.Headers, bool, error) { if len(md) == 0 { - return transport.Headers{}, nil + return transport.Headers{}, false, nil } + headers := transport.NewHeadersWithCapacity(md.Len()) + reportHeader := false + for header, values := range md { header = transport.CanonicalizeHeaderKey(header) if isReserved(header) { continue } + if isReservedWithDollarSign(header) { + reportHeader = true + if enforceHeaderRules { + continue + } + } var value string switch len(values) { case 0: @@ -213,16 +256,19 @@ func getApplicationHeaders(md metadata.MD) (transport.Headers, error) { case 1: value = values[0] default: - return headers, yarpcerrors.InvalidArgumentErrorf("header has more than one value: %s:%v", header, values) + return transport.Headers{}, reportHeader, yarpcerrors.InvalidArgumentErrorf("header has more than one value: %s:%v", header, values) } headers = headers.With(header, value) } - return headers, nil + return headers, reportHeader, nil } // add to md // return error if key already in md func addToMetadata(md metadata.MD, key string, value string) error { + if md == nil { + return nil + } if value == "" { return nil } diff --git a/transport/grpc/headers_test.go b/transport/grpc/headers_test.go index b23c7eab3..e06a9537f 100644 --- a/transport/grpc/headers_test.go +++ b/transport/grpc/headers_test.go @@ -24,23 +24,21 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/yarpcerrors" "google.golang.org/grpc/metadata" ) func TestMetadataToTransportRequest(t *testing.T) { - t.Parallel() - tests := []struct { - Name string - MD metadata.MD - TransportRequest *transport.Request - Error error + tests := map[string]struct { + md metadata.MD + req *transport.Request + enforceHeaderRules bool + expErr error + expReportHeader bool }{ - { - Name: "Basic", - MD: metadata.Pairs( + "basic": { + md: metadata.Pairs( CallerHeader, "example-caller", ServiceHeader, "example-service", ShardKeyHeader, "example-shard-key", @@ -51,7 +49,7 @@ func TestMetadataToTransportRequest(t *testing.T) { "foo", "bar", "baz", "bat", ), - TransportRequest: &transport.Request{ + req: &transport.Request{ Caller: "example-caller", Service: "example-service", ShardKey: "example-shard-key", @@ -65,9 +63,8 @@ func TestMetadataToTransportRequest(t *testing.T) { }), }, }, - { - Name: "Content-type", - MD: metadata.Pairs( + "content-type": { + md: metadata.Pairs( CallerHeader, "example-caller", ServiceHeader, "example-service", ShardKeyHeader, "example-shard-key", @@ -77,7 +74,7 @@ func TestMetadataToTransportRequest(t *testing.T) { "foo", "bar", "baz", "bat", ), - TransportRequest: &transport.Request{ + req: &transport.Request{ Caller: "example-caller", Service: "example-service", ShardKey: "example-shard-key", @@ -90,9 +87,8 @@ func TestMetadataToTransportRequest(t *testing.T) { }), }, }, - { - Name: "Content-type overridden", - MD: metadata.Pairs( + "content-type-overridden": { + md: metadata.Pairs( CallerHeader, "example-caller", ServiceHeader, "example-service", ShardKeyHeader, "example-shard-key", @@ -103,7 +99,7 @@ func TestMetadataToTransportRequest(t *testing.T) { "foo", "bar", "baz", "bat", ), - TransportRequest: &transport.Request{ + req: &transport.Request{ Caller: "example-caller", Service: "example-service", ShardKey: "example-shard-key", @@ -116,27 +112,63 @@ func TestMetadataToTransportRequest(t *testing.T) { }), }, }, + "Reserved header key with rpc prefix in application headers": { + md: metadata.Pairs("rpc-any", "any-value"), + req: &transport.Request{ + Headers: transport.HeadersFromMap(map[string]string{"rpc-any": "any-value"}), + }, + expReportHeader: true, + }, + "Reserved header key with $rpc$ prefix in application headers": { + md: metadata.Pairs("$rpc$-any", "any-value"), + req: &transport.Request{ + Headers: transport.HeadersFromMap(map[string]string{"$rpc$-any": "any-value"}), + }, + expReportHeader: true, + }, + "Reserved headers rules are enforced": { + md: metadata.Pairs( + CallerHeader, "example-caller", + ServiceHeader, "example-service", + "rpc-any", "any-value", + "$rpc$-any", "any-value", + "foo", "bar", + "baz", "bat", + ), + req: &transport.Request{ + Caller: "example-caller", + Service: "example-service", + Headers: transport.HeadersFromMap(map[string]string{ + "foo": "bar", + "baz": "bat", + }), + }, + enforceHeaderRules: true, + expReportHeader: true, + }, } - for _, tt := range tests { - t.Run(tt.Name, func(t *testing.T) { - transportRequest, err := metadataToTransportRequest(tt.MD) - require.Equal(t, tt.Error, err) - require.Equal(t, tt.TransportRequest, transportRequest) + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + switchEnforceHeaderRules(t, tt.enforceHeaderRules) + + transportRequest, reportHeader, err := metadataToInboundRequest(tt.md) + assert.Equal(t, tt.expErr, err) + assert.Equal(t, tt.req, transportRequest) + assert.Equal(t, tt.expReportHeader, reportHeader) }) } } func TestTransportRequestToMetadata(t *testing.T) { - t.Parallel() - for _, tt := range []struct { - Name string - MD metadata.MD - TransportRequest *transport.Request - Error error + for name, tt := range map[string]struct { + md metadata.MD + req *transport.Request + enforceHeaderRules bool + expErr error + expReportHeader bool }{ - { - Name: "Basic", - MD: metadata.Pairs( + "basic": { + md: metadata.Pairs( CallerHeader, "example-caller", ServiceHeader, "example-service", ShardKeyHeader, "example-shard-key", @@ -147,7 +179,7 @@ func TestTransportRequestToMetadata(t *testing.T) { "foo", "bar", "baz", "bat", ), - TransportRequest: &transport.Request{ + req: &transport.Request{ Caller: "example-caller", Service: "example-service", ShardKey: "example-shard-key", @@ -161,21 +193,44 @@ func TestTransportRequestToMetadata(t *testing.T) { }), }, }, - { - Name: "Reserved header key in application headers", - MD: metadata.Pairs(), - TransportRequest: &transport.Request{ + "Reserved header key in application headers": { + md: metadata.Pairs(), + req: &transport.Request{ Headers: transport.HeadersFromMap(map[string]string{ CallerHeader: "example-caller", }), }, - Error: yarpcerrors.InvalidArgumentErrorf("cannot use reserved header in application headers: %s", CallerHeader), + expErr: yarpcerrors.InvalidArgumentErrorf("cannot use reserved header in application headers: %s", CallerHeader), + }, + "Reserved header key with $rpc$ prefix in application headers": { + md: metadata.Pairs("$rpc$-any", "example-caller"), + req: &transport.Request{ + Headers: transport.HeadersFromMap(map[string]string{ + "$rpc$-any": "example-caller", + }), + }, + expErr: nil, + expReportHeader: true, + }, + "Reserved header key with $rpc$ prefix in application headers with enforced rules": { + md: metadata.Pairs(), + req: &transport.Request{ + Headers: transport.HeadersFromMap(map[string]string{ + "$rpc$-any": "example-caller", + }), + }, + enforceHeaderRules: true, + expErr: yarpcerrors.InternalErrorf("cannot use reserved header in application headers: $rpc$-any"), + expReportHeader: true, }, } { - t.Run(tt.Name, func(t *testing.T) { - md, err := transportRequestToMetadata(tt.TransportRequest) - require.Equal(t, tt.Error, err) - require.Equal(t, tt.MD, md) + t.Run(name, func(t *testing.T) { + switchEnforceHeaderRules(t, tt.enforceHeaderRules) + + md, reportHeader, err := outboundRequestToMetadata(tt.req) + assert.Equal(t, tt.expErr, err) + assert.Equal(t, tt.md, md) + assert.Equal(t, tt.expReportHeader, reportHeader) }) } } @@ -203,6 +258,12 @@ func TestIsReserved(t *testing.T) { assert.True(t, isReserved(RoutingDelegateHeader)) assert.True(t, isReserved(EncodingHeader)) assert.True(t, isReserved("rpc-foo")) + assert.False(t, isReserved("$rpc$-foo")) +} + +func TestIsReservedWithDollarSign(t *testing.T) { + assert.False(t, isReservedWithDollarSign("rpc-foo")) + assert.True(t, isReservedWithDollarSign("$rpc$-foo")) } func TestMDReadWriterDuplicateKey(t *testing.T) { @@ -216,55 +277,181 @@ func TestMDReadWriterDuplicateKey(t *testing.T) { } func TestGetApplicationHeaders(t *testing.T) { - tests := []struct { - msg string - meta metadata.MD - wantHeaders map[string]string - wantErr string + tests := map[string]struct { + md metadata.MD + enforceHeaderRules bool + expHeaders map[string]string + expErr error + expReportHeader bool }{ - { - msg: "nil", - meta: nil, - wantHeaders: nil, + "nil": { + md: nil, + expHeaders: nil, }, - { - msg: "empty", - meta: metadata.MD{}, - wantHeaders: nil, + "empty": { + md: metadata.MD{}, + expHeaders: nil, }, - { - msg: "success", - meta: metadata.MD{ + "success": { + md: metadata.MD{ "rpc-service": []string{"foo"}, // reserved header "test-header-empty": []string{}, // no value "test-header-valid-1": []string{"test-value-1"}, "test-Header-Valid-2": []string{"test-value-2"}, }, - wantHeaders: map[string]string{ + expHeaders: map[string]string{ "test-header-valid-1": "test-value-1", "test-header-valid-2": "test-value-2", }, }, - { - msg: "error: multiple values for one header", - meta: metadata.MD{ + "error: multiple values for one header": { + md: metadata.MD{ "test-header-valid": []string{"test-value"}, "test-header-dup": []string{"test-value-1", "test-value-2"}, }, - wantErr: "header has more than one value: test-header-dup:[test-value-1 test-value-2]", + expErr: yarpcerrors.InvalidArgumentErrorf("header has more than one value: test-header-dup:[test-value-1 test-value-2]"), + }, + "reserved header": { + md: metadata.MD{ + "$rpc$-any": []string{"test-value"}, + }, + expHeaders: map[string]string{"$rpc$-any": "test-value"}, + expReportHeader: true, + }, + "reserved header with enforced header rules": { + md: metadata.MD{ + "rpc-any": []string{"test-value"}, + "$rpc$-any": []string{"test-value"}, + "foo": []string{"bar"}, + }, + enforceHeaderRules: true, + expHeaders: map[string]string{"foo": "bar"}, + expReportHeader: true, }, } - for _, tt := range tests { - t.Run(tt.msg, func(t *testing.T) { - got, err := getApplicationHeaders(tt.meta) - if tt.wantErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantErr, "unexpecte error message") - return - } - require.NoError(t, err, "failed to extract application headers") - assert.Equal(t, tt.wantHeaders, got.Items(), "unexpected headers") + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + switchEnforceHeaderRules(t, tt.enforceHeaderRules) + + headers, reportHeader, err := getOutboundResponseApplicationHeaders(tt.md) + assert.Equal(t, tt.expErr, err) + assert.Equal(t, tt.expReportHeader, reportHeader) + assert.Equal(t, tt.expHeaders, headers.Items()) + }) + } +} + +func TestAddApplicationHeaders(t *testing.T) { + dp := map[string]struct { + md metadata.MD + h transport.Headers + enforceHeaderRules bool + expMD metadata.MD + expErr error + expReportHeader bool + }{ + "success": { + md: metadata.Pairs("foo", "bar"), + h: transport.HeadersFromMap(map[string]string{ + "baz": "qux", + }), + expMD: metadata.Pairs("foo", "bar", "baz", "qux"), + }, + "reserved-rpc-prefix": { + md: metadata.Pairs("foo", "bar"), + h: transport.HeadersFromMap(map[string]string{ + "rpc-baz": "qux", + }), + expMD: metadata.Pairs("foo", "bar"), + expErr: yarpcerrors.InvalidArgumentErrorf("cannot use reserved header in application headers: rpc-baz"), + expReportHeader: false, // it's not a new behaviour + }, + "reserved-dollar-rpc-prefix": { + md: metadata.Pairs("foo", "bar"), + h: transport.HeadersFromMap(map[string]string{ + "$rpc$-baz": "qux", + }), + expMD: metadata.Pairs("foo", "bar", "$rpc$-baz", "qux"), + expErr: nil, + expReportHeader: true, + }, + "reserved-dollar-rpc-prefix-enforced-rule": { + md: metadata.Pairs("foo", "bar"), + h: transport.HeadersFromMap(map[string]string{ + "$rpc$-baz": "qux", + }), + enforceHeaderRules: true, + expMD: metadata.Pairs("foo", "bar"), + expErr: yarpcerrors.InternalErrorf("cannot use reserved header in application headers: $rpc$-baz"), + expReportHeader: true, + }, + } + + for name, tt := range dp { + t.Run(name, func(t *testing.T) { + switchEnforceHeaderRules(t, tt.enforceHeaderRules) + + reportHeader, err := addApplicationHeaders(tt.md, tt.h) + assert.Equal(t, err, tt.expErr) + assert.Equal(t, tt.expMD, tt.md) + assert.Equal(t, tt.expReportHeader, reportHeader) + }) + + } +} + +func TestAddToMetadata(t *testing.T) { + dp := map[string]struct { + md metadata.MD + key string + value string + expErr error + expMD metadata.MD + }{ + "nil-md": { + md: nil, + key: "foo", + value: "bar", + expMD: nil, + }, + "empty-value-ignored": { + md: metadata.Pairs(), + key: "foo", + value: "", + expMD: metadata.Pairs(), + }, + "duplicate-key": { + md: metadata.Pairs("foo", "bar"), + key: "foo", + value: "baz", + expErr: yarpcerrors.InvalidArgumentErrorf("duplicate key: foo"), + expMD: metadata.Pairs("foo", "bar"), + }, + "success": { + md: metadata.Pairs("foo", "bar"), + key: "baz", + value: "qux", + expMD: metadata.Pairs("foo", "bar", "baz", "qux"), + }, + } + + for name, tt := range dp { + t.Run(name, func(t *testing.T) { + err := addToMetadata(tt.md, tt.key, tt.value) + assert.Equal(t, err, tt.expErr) + assert.Equal(t, tt.expMD, tt.md) }) } } + +func switchEnforceHeaderRules(t *testing.T, cond bool) { + if !cond { + return + } + + enforceHeaderRules = true + t.Cleanup(func() { + enforceHeaderRules = false + }) +} diff --git a/transport/grpc/outbound.go b/transport/grpc/outbound.go index 7b585ecaa..dfa92d2d1 100644 --- a/transport/grpc/outbound.go +++ b/transport/grpc/outbound.go @@ -34,6 +34,7 @@ import ( "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/api/x/introspection" "go.uber.org/yarpc/internal/grpcerrorcodes" + "go.uber.org/yarpc/internal/observability" intyarpcerrors "go.uber.org/yarpc/internal/yarpcerrors" peerchooser "go.uber.org/yarpc/peer" "go.uber.org/yarpc/peer/hostport" @@ -122,7 +123,10 @@ func (o *Outbound) Call(ctx context.Context, request *transport.Request) (*trans var responseMD metadata.MD invokeErr := o.invoke(ctx, request, &responseBody, &responseMD, start) - responseHeaders, err := getApplicationHeaders(responseMD) + responseHeaders, reportHeader, err := getOutboundResponseApplicationHeaders(responseMD) + if reportHeader { + observability.IncReservedHeaderStripped(o.t.options.meter, request.Caller, request.Service) + } if err != nil { return nil, err } @@ -161,7 +165,10 @@ func (o *Outbound) invoke( responseMD *metadata.MD, start time.Time, ) (retErr error) { - md, err := transportRequestToMetadata(request) + md, reportHeader, err := outboundRequestToMetadata(request) + if reportHeader { + observability.IncReservedHeaderError(o.t.options.meter, request.Caller, request.Service) + } if err != nil { return err } @@ -301,7 +308,10 @@ func (o *Outbound) stream( return nil, err } - md, err := transportRequestToMetadata(treq) + md, reportHeader, err := outboundRequestToMetadata(treq) + if reportHeader { + observability.IncReservedHeaderError(o.t.options.meter, treq.Caller, treq.Service) + } if err != nil { return nil, err } diff --git a/transport/grpc/response_writer.go b/transport/grpc/response_writer.go index c60e20ba5..041008da3 100644 --- a/transport/grpc/response_writer.go +++ b/transport/grpc/response_writer.go @@ -34,9 +34,10 @@ var ( ) type responseWriter struct { - buffer *bytes.Buffer - md metadata.MD - headerErr error + buffer *bytes.Buffer + md metadata.MD + headerErr error + reportHeader bool } func newResponseWriter() *responseWriter { @@ -58,7 +59,14 @@ func (r *responseWriter) AddHeaders(headers transport.Headers) { if r.md == nil { r.md = metadata.New(nil) } - r.headerErr = multierr.Combine(r.headerErr, addApplicationHeaders(r.md, headers)) + + reportHeader, err := addApplicationHeaders(r.md, headers) + if err != nil { + r.headerErr = multierr.Combine(r.headerErr, err) + } + if reportHeader { + r.reportHeader = true + } } func (r *responseWriter) SetApplicationError() { diff --git a/transport/grpc/response_writer_test.go b/transport/grpc/response_writer_test.go new file mode 100644 index 000000000..b25a37feb --- /dev/null +++ b/transport/grpc/response_writer_test.go @@ -0,0 +1,83 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package grpc + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/multierr" + "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/yarpcerrors" + "google.golang.org/grpc/metadata" +) + +func TestResponseWriterAddHeaders(t *testing.T) { + dp := map[string]struct { + h transport.Headers + md metadata.MD + expErr error + expReportHeader bool + expMD metadata.MD + }{ + "md-is-nil": { + h: transport.NewHeaders().With("foo", "bar"), + md: nil, + expMD: metadata.Pairs("foo", "bar"), + }, + "success": { + h: transport.NewHeaders().With("foo", "bar"), + md: metadata.Pairs(), + expMD: metadata.Pairs("foo", "bar"), + }, + "reserved-header-used": { + h: transport.NewHeaders().With("rpc-any", "any-value"), + md: metadata.Pairs(), + expErr: yarpcerrors.InvalidArgumentErrorf("cannot use reserved header in application headers: rpc-any"), + expMD: metadata.Pairs(), + }, + "report-header": { + h: transport.NewHeaders().With("$rpc$-any", "any-value"), + md: metadata.Pairs(), + expMD: metadata.Pairs("$rpc$-any", "any-value"), + expReportHeader: true, + }, + } + + for name, tt := range dp { + t.Run(name, func(t *testing.T) { + rw := newResponseWriter() + rw.md = tt.md + + rw.AddHeaders(tt.h) + if tt.expErr != nil { + errs := multierr.Errors(rw.headerErr) + require.Len(t, errs, 1) + assert.Equal(t, tt.expErr, errs[0]) + } else { + assert.NoError(t, rw.headerErr) + } + assert.Equal(t, tt.expReportHeader, rw.reportHeader) + assert.Equal(t, tt.expMD, rw.md) + }) + } +} diff --git a/transport/tchannel/channel_outbound.go b/transport/tchannel/channel_outbound.go index 21e75526f..2ced00bba 100644 --- a/transport/tchannel/channel_outbound.go +++ b/transport/tchannel/channel_outbound.go @@ -26,6 +26,7 @@ import ( "github.com/uber/tchannel-go" "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/api/x/introspection" + "go.uber.org/yarpc/internal/observability" intyarpcerrors "go.uber.org/yarpc/internal/yarpcerrors" "go.uber.org/yarpc/pkg/errors" "go.uber.org/yarpc/pkg/lifecycle" @@ -195,7 +196,9 @@ func (o *ChannelOutbound) Call(ctx context.Context, req *transport.Request) (*tr } err = getResponseError(headers) - deleteReservedHeaders(headers) + if deleteReservedHeaders(headers) { + observability.IncReservedHeaderStripped(o.transport.meter, req.Caller, req.Service) + } resp := &transport.Response{ Headers: headers, diff --git a/transport/tchannel/channel_transport.go b/transport/tchannel/channel_transport.go index 5faf47173..e4e26f0e9 100644 --- a/transport/tchannel/channel_transport.go +++ b/transport/tchannel/channel_transport.go @@ -25,6 +25,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/uber/tchannel-go" + "go.uber.org/net/metrics" "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/pkg/lifecycle" "go.uber.org/zap" @@ -87,8 +88,9 @@ func (options transportOptions) newChannelTransport() *ChannelTransport { addr: options.addr, tracer: options.tracer, logger: logger.Named("tchannel"), + meter: options.meter, originalHeaders: options.originalHeaders, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, } } @@ -102,6 +104,7 @@ type ChannelTransport struct { addr string tracer opentracing.Tracer logger *zap.Logger + meter *metrics.Scope router transport.Router originalHeaders bool newResponseWriter func(inboundCallResponse, tchannel.Format, headerCase) responseWriter diff --git a/transport/tchannel/handler.go b/transport/tchannel/handler.go index 421b9aa27..90263aced 100644 --- a/transport/tchannel/handler.go +++ b/transport/tchannel/handler.go @@ -23,15 +23,15 @@ package tchannel import ( "bytes" "context" - "fmt" - "strconv" "time" "github.com/opentracing/opentracing-go" "github.com/uber/tchannel-go" "go.uber.org/multierr" + "go.uber.org/net/metrics" "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/internal/bufferpool" + "go.uber.org/yarpc/internal/observability" "go.uber.org/yarpc/pkg/errors" "go.uber.org/yarpc/yarpcerrors" "go.uber.org/zap" @@ -72,18 +72,17 @@ type inboundCallResponse interface { SetApplicationError() error } -// responseWriter provides an interface similar to handlerWriter. -// -// It allows us to control handlerWriter during testing. +// responseWriter enhances transport.ResponseWriter interface with transport specific +// methods. type responseWriter interface { - AddHeaders(h transport.Headers) - AddHeader(key string, value string) + transport.ResponseWriter + + AddSystemHeader(key string, value string) Close() error ReleaseBuffer() IsApplicationError() bool - SetApplicationError() SetApplicationErrorMeta(meta *transport.ApplicationErrorMeta) - Write(s []byte) (int, error) + IsReservedHeaderUsed() bool } // tchannelCall wraps a TChannel InboundCall into an inboundCall. @@ -103,6 +102,7 @@ type handler struct { tracer opentracing.Tracer headerCase headerCase logger *zap.Logger + meter *metrics.Scope newResponseWriter func(inboundCallResponse, tchannel.Format, headerCase) responseWriter excludeServiceHeaderInResponse bool } @@ -118,10 +118,13 @@ func (h handler) handle(ctx context.Context, call inboundCall) { if !h.excludeServiceHeaderInResponse { // echo accepted rpc-service in response header - responseWriter.AddHeader(ServiceHeaderKey, call.ServiceName()) + responseWriter.AddSystemHeader(ServiceHeaderKey, call.ServiceName()) } err := h.callHandler(ctx, call, responseWriter) + if responseWriter.IsReservedHeaderUsed() { + observability.IncReservedHeaderError(h.meter, call.CallerName(), call.ServiceName()) + } // black-hole requests on resource exhausted errors if yarpcerrors.FromError(err).Code() == yarpcerrors.CodeResourceExhausted { @@ -147,12 +150,12 @@ func (h handler) handle(ctx context.Context, call inboundCall) { // TODO: what to do with error? we could have a whole complicated scheme to // return a SystemError here, might want to do that text, _ := status.Code().MarshalText() - responseWriter.AddHeader(ErrorCodeHeaderKey, string(text)) + responseWriter.AddSystemHeader(ErrorCodeHeaderKey, string(text)) if status.Name() != "" { - responseWriter.AddHeader(ErrorNameHeaderKey, status.Name()) + responseWriter.AddSystemHeader(ErrorNameHeaderKey, status.Name()) } if status.Message() != "" { - responseWriter.AddHeader(ErrorMessageHeaderKey, status.Message()) + responseWriter.AddSystemHeader(ErrorMessageHeaderKey, status.Message()) } } if reswErr := responseWriter.Close(); reswErr != nil && !clientTimedOut { @@ -186,9 +189,10 @@ func (h handler) callHandler(ctx context.Context, call inboundCall, responseWrit return errors.RequestHeadersDecodeError(treq, err) } - // callerProcedure is a rpc header but recevied in application headers, so moving this header to transprotRequest - // by updating treq.CallerProcedure. - treq = headerCallerProcedureToRequest(treq, &headers) + transportHeadersToRequest(treq, headers) + if deleteReservedHeaders(headers) { + observability.IncReservedHeaderStripped(h.meter, call.CallerName(), call.ServiceName()) + } treq.Headers = headers if tcall, ok := call.(tchannelCall); ok { @@ -251,119 +255,6 @@ func (h handler) callHandler(ctx context.Context, call inboundCall, responseWrit } } -type handlerWriter struct { - failedWith error - format tchannel.Format - headers transport.Headers - buffer *bufferpool.Buffer - response inboundCallResponse - applicationError bool - headerCase headerCase -} - -func newHandlerWriter(response inboundCallResponse, format tchannel.Format, headerCase headerCase) responseWriter { - return &handlerWriter{ - response: response, - format: format, - headerCase: headerCase, - } -} - -func (hw *handlerWriter) AddHeaders(h transport.Headers) { - for k, v := range h.OriginalItems() { - if isReservedHeaderKey(k) { - hw.failedWith = appendError(hw.failedWith, fmt.Errorf("cannot use reserved header key: %s", k)) - return - } - hw.AddHeader(k, v) - } -} - -func (hw *handlerWriter) AddHeader(key string, value string) { - hw.headers = hw.headers.With(key, value) -} - -func (hw *handlerWriter) SetApplicationError() { - hw.applicationError = true -} - -func (hw *handlerWriter) SetApplicationErrorMeta(applicationErrorMeta *transport.ApplicationErrorMeta) { - if applicationErrorMeta == nil { - return - } - if applicationErrorMeta.Code != nil { - hw.AddHeader(ApplicationErrorCodeHeaderKey, strconv.Itoa(int(*applicationErrorMeta.Code))) - } - if applicationErrorMeta.Name != "" { - hw.AddHeader(ApplicationErrorNameHeaderKey, applicationErrorMeta.Name) - } - if applicationErrorMeta.Details != "" { - hw.AddHeader(ApplicationErrorDetailsHeaderKey, truncateAppErrDetails(applicationErrorMeta.Details)) - } -} - -func truncateAppErrDetails(val string) string { - if len(val) <= _maxAppErrDetailsHeaderLen { - return val - } - stripIndex := _maxAppErrDetailsHeaderLen - len(_truncatedHeaderMessage) - return val[:stripIndex] + _truncatedHeaderMessage -} - -func (hw *handlerWriter) IsApplicationError() bool { - return hw.applicationError -} - -func (hw *handlerWriter) Write(s []byte) (int, error) { - if hw.failedWith != nil { - return 0, hw.failedWith - } - - if hw.buffer == nil { - hw.buffer = bufferpool.Get() - } - - n, err := hw.buffer.Write(s) - if err != nil { - hw.failedWith = appendError(hw.failedWith, err) - } - return n, err -} - -func (hw *handlerWriter) Close() error { - retErr := hw.failedWith - if hw.IsApplicationError() { - if err := hw.response.SetApplicationError(); err != nil { - retErr = appendError(retErr, fmt.Errorf("SetApplicationError() failed: %v", err)) - } - } - - headers := headerMap(hw.headers, hw.headerCase) - retErr = appendError(retErr, writeHeaders(hw.format, headers, nil, hw.response.Arg2Writer)) - - // Arg3Writer must be opened and closed regardless of if there is data - // However, if there is a system error, we do not want to do this - bodyWriter, err := hw.response.Arg3Writer() - if err != nil { - return appendError(retErr, err) - } - defer func() { retErr = appendError(retErr, bodyWriter.Close()) }() - if hw.buffer != nil { - if _, err := hw.buffer.WriteTo(bodyWriter); err != nil { - return appendError(retErr, err) - } - } - - return retErr -} - -func (hw *handlerWriter) ReleaseBuffer() { - if hw.buffer != nil { - bufferpool.Put(hw.buffer) - hw.buffer = nil - } -} - func getSystemError(err error) error { if _, ok := err.(tchannel.SystemError); ok { return err diff --git a/transport/tchannel/handler_test.go b/transport/tchannel/handler_test.go index cae62a451..3386ef019 100644 --- a/transport/tchannel/handler_test.go +++ b/transport/tchannel/handler_test.go @@ -67,7 +67,7 @@ func TestHandlerErrors(t *testing.T) { format: tchannel.JSON, headers: []byte(`{"Rpc-Header-Foo": "bar"}`), wantHeaders: map[string]string{"rpc-header-foo": "bar"}, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, recorder: newResponseRecorder(), }, { @@ -79,7 +79,7 @@ func TestHandlerErrors(t *testing.T) { 0x00, 0x03, 'B', 'a', 'r', // Bar }, wantHeaders: map[string]string{"foo": "Bar"}, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, recorder: newResponseRecorder(), }, { @@ -199,7 +199,7 @@ func TestHandlerFailures(t *testing.T) { arg3: []byte{0x00}, }, wantStatus: tchannel.ErrCodeBadRequest, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, recorder: newResponseRecorder(), wantLogLevel: zapcore.ErrorLevel, }, @@ -214,7 +214,7 @@ func TestHandlerFailures(t *testing.T) { arg3: []byte{0x00}, }, wantStatus: tchannel.ErrCodeBadRequest, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, recorder: newResponseRecorder(), wantLogLevel: zapcore.ErrorLevel, }, @@ -229,7 +229,7 @@ func TestHandlerFailures(t *testing.T) { arg3: []byte{0x00}, }, wantStatus: tchannel.ErrCodeBadRequest, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, recorder: newResponseRecorder(), wantLogLevel: zapcore.ErrorLevel, }, @@ -244,7 +244,7 @@ func TestHandlerFailures(t *testing.T) { arg3: nil, }, wantStatus: tchannel.ErrCodeUnexpected, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, recorder: newResponseRecorder(), wantLogLevel: zapcore.ErrorLevel, }, @@ -274,7 +274,7 @@ func TestHandlerFailures(t *testing.T) { ).Return(fmt.Errorf("great sadness")) }, wantStatus: tchannel.ErrCodeUnexpected, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, recorder: newResponseRecorder(), wantLogLevel: zapcore.ErrorLevel, }, @@ -307,7 +307,7 @@ func TestHandlerFailures(t *testing.T) { ))) }, wantStatus: tchannel.ErrCodeBadRequest, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, recorder: newResponseRecorder(), wantLogLevel: zapcore.ErrorLevel, }, @@ -343,7 +343,7 @@ func TestHandlerFailures(t *testing.T) { }).Return(context.DeadlineExceeded) }, wantStatus: tchannel.ErrCodeTimeout, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, recorder: newResponseRecorder(), wantLogLevel: zapcore.ErrorLevel, }, @@ -376,7 +376,7 @@ func TestHandlerFailures(t *testing.T) { }) }, wantStatus: tchannel.ErrCodeUnexpected, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, recorder: newResponseRecorder(), wantLogLevel: zapcore.ErrorLevel, wantLogMessage: "Unary handler panicked", @@ -392,7 +392,7 @@ func TestHandlerFailures(t *testing.T) { arg3: []byte{0x00}, }, wantStatus: tchannel.ErrCodeBadRequest, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, recorder: newFaultyResponseRecorder(), wantLogLevel: zapcore.ErrorLevel, wantLogMessage: "SendSystemError failed", @@ -580,7 +580,7 @@ func TestResponseWriter(t *testing.T) { resp := newResponseRecorder() call.resp = resp - w := newHandlerWriter(call.Response(), call.Format(), tt.headerCase) + w := newResponseWriter(call.Response(), call.Format(), tt.headerCase) tt.apply(w) assert.NoError(t, w.Close()) @@ -623,7 +623,7 @@ func TestResponseWriterFailure(t *testing.T) { resp := newResponseRecorder() tt.setupResp(resp) - w := newHandlerWriter(resp, tchannel.Raw, canonicalizedHeaderCase) + w := newResponseWriter(resp, tchannel.Raw, canonicalizedHeaderCase) _, err := w.Write([]byte("foo")) assert.NoError(t, err) _, err = w.Write([]byte("bar")) @@ -638,7 +638,7 @@ func TestResponseWriterFailure(t *testing.T) { func TestResponseWriterEmptyBodyHeaders(t *testing.T) { res := newResponseRecorder() - w := newHandlerWriter(res, tchannel.Raw, canonicalizedHeaderCase) + w := newResponseWriter(res, tchannel.Raw, canonicalizedHeaderCase) w.AddHeaders(transport.NewHeaders().With("foo", "bar")) require.NoError(t, w.Close()) @@ -696,7 +696,7 @@ func TestHandlerSystemErrorLogs(t *testing.T) { tchannelHandler := handler{ router: router, logger: zap.New(zapCore), - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, } router.EXPECT().Choose(gomock.Any(), gomock.Any()).Return(spec, nil).Times(4) @@ -806,7 +806,7 @@ func TestTruncatedHeader(t *testing.T) { } func TestRpcServiceHeader(t *testing.T) { - hw := &handlerWriter{} + hw := &responseWriterImpl{} h := handler{ headerCase: canonicalizedHeaderCase, newResponseWriter: func(inboundCallResponse, tchannel.Format, headerCase) responseWriter { diff --git a/transport/tchannel/header.go b/transport/tchannel/header.go index 1f80391b4..ab8dd38b3 100644 --- a/transport/tchannel/header.go +++ b/transport/tchannel/header.go @@ -68,11 +68,23 @@ var _reservedHeaderKeys = map[string]struct{}{ CallerProcedureHeader: {}, } +var ( + // enforceHeaderRules is a feature flag for a more strict error handling rules. + // See https://github.com/yarpc/yarpc-go/pull/2259 for more details. + enforceHeaderRules = false +) + +// isReservedHeaderKey checks header name by exact match. func isReservedHeaderKey(key string) bool { _, ok := _reservedHeaderKeys[strings.ToLower(key)] return ok } +// isReservedHeaderPrefix checks header name by prefix match. +func isReservedHeaderPrefix(header string) bool { + return strings.HasPrefix(strings.ToLower(header), "rpc-") || strings.HasPrefix(strings.ToLower(header), "$rpc$-") +} + // readRequestHeaders reads headers and baggage from an incoming request. func readRequestHeaders( ctx context.Context, @@ -177,19 +189,17 @@ func decodeHeaders(r io.Reader) (transport.Headers, error) { return headers, reader.Err() } -// headerCallerProcedureToRequest copies callerProcedure from headers to req.CallerProcedure -// and then deletes it from headers. -func headerCallerProcedureToRequest(req *transport.Request, headers *transport.Headers) *transport.Request { +// transportHeadersToRequest copies custom (which are not part of tchannel protocol) transport header values to request +// and then deletes them from headers list. +func transportHeadersToRequest(req *transport.Request, headers transport.Headers) { if callerProcedure, ok := headers.Get(CallerProcedureHeader); ok { req.CallerProcedure = callerProcedure headers.Del(CallerProcedureHeader) - return req } - return req } -// requestCallerProcedureToHeader add callerProcedure header as an application header. -func requestCallerProcedureToHeader(req *transport.Request, reqHeaders map[string]string) map[string]string { +// requestToTransportHeaders adds custom (which are not part of tchannel protocol) transport headers from request. +func requestToTransportHeaders(req *transport.Request, reqHeaders map[string]string) map[string]string { if req.CallerProcedure == "" { return reqHeaders } @@ -197,7 +207,9 @@ func requestCallerProcedureToHeader(req *transport.Request, reqHeaders map[strin if reqHeaders == nil { reqHeaders = make(map[string]string) } + reqHeaders[CallerProcedureHeader] = req.CallerProcedure + return reqHeaders } @@ -227,7 +239,7 @@ func encodeHeaders(hs map[string]string) []byte { return out } -func headerMap(hs transport.Headers, headerCase headerCase) map[string]string { +func getHeaderMap(hs transport.Headers, headerCase headerCase) map[string]string { switch headerCase { case originalHeaderCase: return hs.OriginalItems() @@ -236,10 +248,30 @@ func headerMap(hs transport.Headers, headerCase headerCase) map[string]string { } } -func deleteReservedHeaders(headers transport.Headers) { +func validateApplicationHeaders(headers map[string]string) error { + for key := range headers { + if isReservedHeaderPrefix(key) { + return yarpcerrors.InternalErrorf("header with rpc prefix is not allowed in request application headers (%s was passed)", key) + } + } + return nil +} + +func deleteReservedHeaders(headers transport.Headers) (reportHeader bool) { for headerKey := range _reservedHeaderKeys { headers.Del(headerKey) } + + for key := range headers.Items() { + if isReservedHeaderPrefix(key) { + reportHeader = true + if enforceHeaderRules { + headers.Del(key) + } + } + } + + return } // this check ensures that the service we're issuing a request to is the one diff --git a/transport/tchannel/header_test.go b/transport/tchannel/header_test.go index f8d686bea..4072940d5 100644 --- a/transport/tchannel/header_test.go +++ b/transport/tchannel/header_test.go @@ -100,7 +100,7 @@ func TestAddCallerProcedureHeader(t *testing.T) { }, } { t.Run(tt.desc, func(t *testing.T) { - headers := requestCallerProcedureToHeader(&tt.treq, tt.headers) + headers := requestToTransportHeaders(&tt.treq, tt.headers) assert.Equal(t, tt.expectedHeaders, headers) }) } @@ -134,8 +134,8 @@ func TestMoveCallerProcedureToRequest(t *testing.T) { } { t.Run(tt.desc, func(t *testing.T) { headers := transport.HeadersFromMap(tt.headers) - treq := headerCallerProcedureToRequest(&tt.treq, &headers) - assert.Equal(t, tt.expectedTreq, *treq) + transportHeadersToRequest(&tt.treq, headers) + assert.Equal(t, tt.expectedTreq, tt.treq) assert.Equal(t, transport.HeadersFromMap(tt.expectedHeaders), headers) }) } @@ -360,3 +360,121 @@ func TestValidateServiceHeaders(t *testing.T) { }) } } + +func TestDeleteReservedHeaders(t *testing.T) { + dp := map[string]struct { + headers map[string]string + enforceHeaderRule bool + expHeaders map[string]string + expReportHeaders bool + }{ + "nil-headers": {}, + "no-reserved-headers": { + headers: map[string]string{ + "any-header": "any-value", + }, + expHeaders: map[string]string{ + "any-header": "any-value", + }, + }, + "reserved-known-headers": { + headers: map[string]string{ + ServiceHeaderKey: "any-value", + "any-header": "any-value", + }, + expHeaders: map[string]string{ + "any-header": "any-value", + }, + }, + "reserved-rpc-headers": { + headers: map[string]string{ + "rpc-any": "any-value", + "any-header": "any-value", + }, + expHeaders: map[string]string{ + "rpc-any": "any-value", + "any-header": "any-value", + }, + expReportHeaders: true, + }, + "reserved-dollar-rpc-headers": { + headers: map[string]string{ + "$rpc$-any": "any-value", + "any-header": "any-value", + }, + expHeaders: map[string]string{ + "$rpc$-any": "any-value", + "any-header": "any-value", + }, + expReportHeaders: true, + }, + "enforce-header-rules": { + headers: map[string]string{ + "rpc-any": "any-value", + "$rpc$-any": "any-value", + "any-header": "any-value", + }, + enforceHeaderRule: true, + expHeaders: map[string]string{ + "any-header": "any-value", + }, + expReportHeaders: true, + }, + } + + for name, tt := range dp { + t.Run(name, func(t *testing.T) { + switchEnforceHeaderRules(t, tt.enforceHeaderRule) + + headers := transport.HeadersFromMap(tt.headers) + reportHeaders := deleteReservedHeaders(headers) + assert.Equal(t, tt.expReportHeaders, reportHeaders) + assert.Equal(t, transport.HeadersFromMap(tt.expHeaders), headers) + }) + + } +} + +func TestValidateApplicationHeaders(t *testing.T) { + dp := map[string]struct { + headers map[string]string + expErr error + }{ + "no-headers-no-error": {}, + "valid-headers-no-error": { + headers: map[string]string{ + "valid-key": "valid-value", + }, + }, + "reserved-rpc-header-error": { + headers: map[string]string{ + "rpc-any": "any-value", + }, + expErr: yarpcerrors.InternalErrorf("header with rpc prefix is not allowed in request application headers (rpc-any was passed)"), + }, + "reserved-dollad-rpc-header-error": { + headers: map[string]string{ + "$rpc$-any": "any-value", + }, + expErr: yarpcerrors.InternalErrorf("header with rpc prefix is not allowed in request application headers ($rpc$-any was passed)"), + }, + } + + for name, tt := range dp { + t.Run(name, func(t *testing.T) { + err := validateApplicationHeaders(tt.headers) + assert.Equal(t, tt.expErr, err) + }) + } +} + +func switchEnforceHeaderRules(t *testing.T, cond bool) { + if !cond { + return + } + + enforceHeaderRules = true + t.Cleanup(func() { + enforceHeaderRules = false + }) +} diff --git a/transport/tchannel/outbound.go b/transport/tchannel/outbound.go index 7676a9db3..2e7ee72ac 100644 --- a/transport/tchannel/outbound.go +++ b/transport/tchannel/outbound.go @@ -27,11 +27,13 @@ import ( "strconv" "github.com/uber/tchannel-go" + "go.uber.org/net/metrics" "go.uber.org/yarpc/api/peer" "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/api/x/introspection" "go.uber.org/yarpc/internal/bufferpool" "go.uber.org/yarpc/internal/iopool" + "go.uber.org/yarpc/internal/observability" intyarpcerrors "go.uber.org/yarpc/internal/yarpcerrors" peerchooser "go.uber.org/yarpc/peer" "go.uber.org/yarpc/peer/hostport" @@ -121,11 +123,11 @@ func (o *Outbound) Call(ctx context.Context, req *transport.Request) (*transport // Call sends an RPC to this specific peer. func (p *tchannelPeer) Call(ctx context.Context, req *transport.Request, reuseBuffer bool) (*transport.Response, error) { - return callWithPeer(ctx, req, p.getPeer(), p.transport.headerCase, reuseBuffer) + return callWithPeer(ctx, req, p.getPeer(), p.transport.headerCase, reuseBuffer, p.transport.meter) } // callWithPeer sends a request with the chosen peer. -func callWithPeer(ctx context.Context, req *transport.Request, peer *tchannel.Peer, headerCase headerCase, reuseBuffer bool) (*transport.Response, error) { +func callWithPeer(ctx context.Context, req *transport.Request, peer *tchannel.Peer, headerCase headerCase, reuseBuffer bool, meter *metrics.Scope) (*transport.Response, error) { // NB(abg): Under the current API, the local service's name is required // twice: once when constructing the TChannel and then again when // constructing the RPC. @@ -153,14 +155,20 @@ func callWithPeer(ctx context.Context, req *transport.Request, peer *tchannel.Pe req.Procedure, &callOptions, ) - if err != nil { return nil, err } - reqHeaders := headerMap(req.Headers, headerCase) - // for tchannel, callerProcedure is added to application headers. - reqHeaders = requestCallerProcedureToHeader(req, reqHeaders) + reqHeaders := getHeaderMap(req.Headers, headerCase) + + if err := validateApplicationHeaders(reqHeaders); err != nil { + observability.IncReservedHeaderError(meter, req.Caller, req.Service) + if enforceHeaderRules { + return nil, err + } + } + + reqHeaders = requestToTransportHeaders(req, reqHeaders) // baggage headers are transport implementation details that are stripped out (and stored in the context). Users don't interact with it tracingBaggage := tchannel.InjectOutboundSpan(call.Response(), nil) @@ -212,7 +220,9 @@ func callWithPeer(ctx context.Context, req *transport.Request, peer *tchannel.Pe applicationErrorDetails, _ := headers.Get(ApplicationErrorDetailsHeaderKey) err = getResponseError(headers) - deleteReservedHeaders(headers) + if deleteReservedHeaders(headers) { + observability.IncReservedHeaderStripped(meter, req.Caller, req.Service) + } resp := &transport.Response{ Headers: headers, diff --git a/transport/tchannel/response_writer.go b/transport/tchannel/response_writer.go new file mode 100644 index 000000000..f58c2a273 --- /dev/null +++ b/transport/tchannel/response_writer.go @@ -0,0 +1,162 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package tchannel + +import ( + "fmt" + "strconv" + + "github.com/uber/tchannel-go" + "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/internal/bufferpool" +) + +type responseWriterImpl struct { + failedWith error + format tchannel.Format + headers transport.Headers + buffer *bufferpool.Buffer + response inboundCallResponse + applicationError bool + headerCase headerCase + reservedHeader bool +} + +func newResponseWriter(response inboundCallResponse, format tchannel.Format, headerCase headerCase) responseWriter { + return &responseWriterImpl{ + response: response, + format: format, + headerCase: headerCase, + } +} + +func (hw *responseWriterImpl) AddHeaders(h transport.Headers) { + for k, v := range h.OriginalItems() { + if !isReservedHeaderPrefix(k) { + hw.addHeader(k, v) + continue + } + + hw.reservedHeader = true + if enforceHeaderRules { + hw.failedWith = appendError(hw.failedWith, fmt.Errorf("header with rpc prefix is not allowed in response application headers (%s was passed)", k)) + return + } else if isReservedHeaderKey(k) { + hw.failedWith = appendError(hw.failedWith, fmt.Errorf("cannot use reserved header key: %s", k)) + return + } else { + hw.addHeader(k, v) + } + } +} + +func (hw *responseWriterImpl) AddSystemHeader(key, value string) { + hw.addHeader(key, value) +} + +func (hw *responseWriterImpl) addHeader(key, value string) { + hw.headers = hw.headers.With(key, value) +} + +func (hw *responseWriterImpl) SetApplicationError() { + hw.applicationError = true +} + +func (hw *responseWriterImpl) SetApplicationErrorMeta(applicationErrorMeta *transport.ApplicationErrorMeta) { + if applicationErrorMeta == nil { + return + } + if applicationErrorMeta.Code != nil { + hw.AddSystemHeader(ApplicationErrorCodeHeaderKey, strconv.Itoa(int(*applicationErrorMeta.Code))) + } + if applicationErrorMeta.Name != "" { + hw.AddSystemHeader(ApplicationErrorNameHeaderKey, applicationErrorMeta.Name) + } + if applicationErrorMeta.Details != "" { + hw.AddSystemHeader(ApplicationErrorDetailsHeaderKey, truncateAppErrDetails(applicationErrorMeta.Details)) + } +} + +func truncateAppErrDetails(val string) string { + if len(val) <= _maxAppErrDetailsHeaderLen { + return val + } + stripIndex := _maxAppErrDetailsHeaderLen - len(_truncatedHeaderMessage) + return val[:stripIndex] + _truncatedHeaderMessage +} + +func (hw *responseWriterImpl) IsApplicationError() bool { + return hw.applicationError +} + +func (hw *responseWriterImpl) Write(s []byte) (int, error) { + if hw.failedWith != nil { + return 0, hw.failedWith + } + + if hw.buffer == nil { + hw.buffer = bufferpool.Get() + } + + n, err := hw.buffer.Write(s) + if err != nil { + hw.failedWith = appendError(hw.failedWith, err) + } + return n, err +} + +func (hw *responseWriterImpl) Close() error { + retErr := hw.failedWith + if hw.IsApplicationError() { + if err := hw.response.SetApplicationError(); err != nil { + retErr = appendError(retErr, fmt.Errorf("SetApplicationError() failed: %v", err)) + } + } + + headers := getHeaderMap(hw.headers, hw.headerCase) + retErr = appendError(retErr, writeHeaders(hw.format, headers, nil, hw.response.Arg2Writer)) + + // Arg3Writer must be opened and closed regardless of if there is data + // However, if there is a system error, we do not want to do this + bodyWriter, err := hw.response.Arg3Writer() + if err != nil { + return appendError(retErr, err) + } + defer func() { retErr = appendError(retErr, bodyWriter.Close()) }() + if hw.buffer != nil { + if _, err := hw.buffer.WriteTo(bodyWriter); err != nil { + return appendError(retErr, err) + } + } + + return retErr +} + +func (hw *responseWriterImpl) ReleaseBuffer() { + if hw.buffer != nil { + bufferpool.Put(hw.buffer) + hw.buffer = nil + } +} + +func (hw *responseWriterImpl) IsReservedHeaderUsed() bool { + return hw.reservedHeader +} diff --git a/transport/tchannel/response_writer_test.go b/transport/tchannel/response_writer_test.go new file mode 100644 index 000000000..c24dccd45 --- /dev/null +++ b/transport/tchannel/response_writer_test.go @@ -0,0 +1,80 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package tchannel + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/multierr" + "go.uber.org/yarpc/api/transport" +) + +func TestResponseWriterAddHeaders(t *testing.T) { + dp := map[string]struct { + h transport.Headers + enforceHeaderRules bool + expErr error + expReservedHeader bool + expHeaders transport.Headers + }{ + "success": { + h: transport.NewHeaders().With("foo", "bar"), + expHeaders: transport.NewHeaders().With("foo", "bar"), + }, + "known-reserved-header-used-which-lead-to-error": { + h: transport.NewHeaders().With(ServiceHeaderKey, "any-value"), + expErr: fmt.Errorf("cannot use reserved header key: %s", ServiceHeaderKey), + expReservedHeader: true, + }, + "unknown-reserved-header-used-which-lead-reporting-metric": { + h: transport.NewHeaders().With("rpc-any", "any-value"), + expHeaders: transport.NewHeaders().With("rpc-any", "any-value"), + expReservedHeader: true, + }, + "enforce-header-rules": { + h: transport.NewHeaders().With("rpc-any", "any-value"), + enforceHeaderRules: true, + expErr: fmt.Errorf("header with rpc prefix is not allowed in response application headers (rpc-any was passed)"), + expReservedHeader: true, + }, + } + + for name, tt := range dp { + t.Run(name, func(t *testing.T) { + switchEnforceHeaderRules(t, tt.enforceHeaderRules) + rw := responseWriterImpl{} + + rw.AddHeaders(tt.h) + if tt.expErr != nil { + errs := multierr.Errors(rw.failedWith) + require.Len(t, errs, 1) + assert.Equal(t, tt.expErr, errs[0]) + } else { + assert.NoError(t, rw.failedWith) + } + assert.Equal(t, tt.expReservedHeader, rw.reservedHeader) + assert.Equal(t, tt.expHeaders, rw.headers) + }) + } +} diff --git a/transport/tchannel/tchannel_utils_test.go b/transport/tchannel/tchannel_utils_test.go index bed8866fb..a69dbfa22 100644 --- a/transport/tchannel/tchannel_utils_test.go +++ b/transport/tchannel/tchannel_utils_test.go @@ -171,7 +171,7 @@ func (fr *faultyResponseRecorder) SendSystemError(err error) error { // faultyHandlerWriter mocks a responseWriter.Close() error to test logging behaviour // inside tchannel.Handle. -type faultyHandlerWriter struct{ handlerWriter } +type faultyHandlerWriter struct{ responseWriterImpl } func newFaultyHandlerWriter(response inboundCallResponse, format tchannel.Format, headerCase headerCase) responseWriter { return &faultyHandlerWriter{} diff --git a/transport/tchannel/transport.go b/transport/tchannel/transport.go index b90e6db2e..8047d9895 100644 --- a/transport/tchannel/transport.go +++ b/transport/tchannel/transport.go @@ -130,7 +130,7 @@ func (o transportOptions) newTransport() *Transport { logger: logger, meter: o.meter, headerCase: headerCase, - newResponseWriter: newHandlerWriter, + newResponseWriter: newResponseWriter, nativeTChannelMethods: o.nativeTChannelMethods, excludeServiceHeaderInResponse: o.excludeServiceHeaderInResponse, inboundTLSConfig: o.inboundTLSConfig, @@ -225,6 +225,7 @@ func (t *Transport) start() error { tracer: t.tracer, headerCase: t.headerCase, logger: t.logger, + meter: t.meter, newResponseWriter: t.newResponseWriter, excludeServiceHeaderInResponse: t.excludeServiceHeaderInResponse, },