Skip to content

Commit

Permalink
[DEMO] preview of redesign tracing instrumentation layer
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenX1993 committed Sep 28, 2024
1 parent ed351c6 commit 3a394bb
Show file tree
Hide file tree
Showing 20 changed files with 706 additions and 137 deletions.
99 changes: 99 additions & 0 deletions api/transport/propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,20 @@ package transport

import (
"context"
"strings"
"sync"
"time"

"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
opentracinglog "github.com/opentracing/opentracing-go/log"
)

const (
tchannelTracingKeyPrefix = "$tracing$"
tchannelTracingKeyMappingSize = 100
)

// CreateOpenTracingSpan creates a new context with a started span
type CreateOpenTracingSpan struct {
Tracer opentracing.Tracer
Expand Down Expand Up @@ -119,3 +126,95 @@ func UpdateSpanWithErr(span opentracing.Span, err error) error {
}
return err
}

// GetPropagationFormat returns the opentracing propagation depends on transport.
// For TChannel, the format is opentracing.TextMap
// For HTTP and gRPC, the format is opentracing.HTTPHeaders
func GetPropagationFormat(transport string) opentracing.BuiltinFormat {
if transport == "tchannel" {
return opentracing.TextMap
}
return opentracing.HTTPHeaders
}

// PropagationCarrier is an interface to combine both reader and writer interface
type PropagationCarrier interface {
opentracing.TextMapReader
opentracing.TextMapWriter
}

// GetPropagationCarrier get the propagation carrier depends on the transport.
// The carrier is used for accessing the transport headers.
// For TChannel, a special carrier is used. For details, see comments of TChannelHeadersCarrier
func GetPropagationCarrier(headers map[string]string, transport string) PropagationCarrier {
if transport == "tchannel" {
return TChannelHeadersCarrier(headers)
}
return opentracing.TextMapCarrier(headers)
}

// TChannelHeadersCarrier is a dedicated carrier for TChannel.
// When writing the tracing headers into headers, the $tracing$ prefix is added to each tracing header key.
// When reading the tracing headers from headers, the $tracing$ prefix is removed from each tracing header key.
type TChannelHeadersCarrier map[string]string

var _ PropagationCarrier = TChannelHeadersCarrier{}

func (c TChannelHeadersCarrier) ForeachKey(handler func(string, string) error) error {
for k, v := range c {
if !strings.HasPrefix(k, tchannelTracingKeyPrefix) {
continue
}
noPrefixKey := tchannelTracingKeyDecoding.mapAndCache(k)
if err := handler(noPrefixKey, v); err != nil {
return err
}
}
return nil
}

func (c TChannelHeadersCarrier) Set(key, value string) {
prefixedKey := tchannelTracingKeyEncoding.mapAndCache(key)
c[prefixedKey] = value
}

// tchannelTracingKeysMapping is to optimize the efficiency of tracing header key manipulations.
// The implementation is forked from tchannel-go: https://github.com/uber/tchannel-go/blob/dev/tracing_keys.go#L36
type tchannelTracingKeysMapping struct {
sync.RWMutex
mapping map[string]string
mapper func(key string) string
}

var tchannelTracingKeyEncoding = &tchannelTracingKeysMapping{
mapping: make(map[string]string),
mapper: func(key string) string {
return tchannelTracingKeyPrefix + key
},
}

var tchannelTracingKeyDecoding = &tchannelTracingKeysMapping{
mapping: make(map[string]string),
mapper: func(key string) string {
return key[len(tchannelTracingKeyPrefix):]
},
}

func (m *tchannelTracingKeysMapping) mapAndCache(key string) string {
m.RLock()
v, ok := m.mapping[key]
m.RUnlock()
if ok {
return v
}
m.Lock()
defer m.Unlock()
if v, ok := m.mapping[key]; ok {
return v
}
mappedKey := m.mapper(key)
if len(m.mapping) < tchannelTracingKeyMappingSize {
m.mapping[key] = mappedKey
}
return mappedKey
}
264 changes: 264 additions & 0 deletions internal/tracingmiddleware/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
// Copyright (c) 2024 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tracingmiddleware

import (
"context"
"time"

"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/opentracing/opentracing-go/log"
"go.uber.org/yarpc/api/middleware"
"go.uber.org/yarpc/api/transport"
"go.uber.org/yarpc/yarpcerrors"
)

var (
_ middleware.UnaryInbound = (*Middleware)(nil)
_ middleware.UnaryOutbound = (*Middleware)(nil)
_ middleware.OnewayInbound = (*Middleware)(nil)
_ middleware.OnewayOutbound = (*Middleware)(nil)
_ middleware.StreamInbound = (*Middleware)(nil)
_ middleware.StreamOutbound = (*Middleware)(nil)
)

type (
unaryHandlerFunc func(context.Context, *transport.Request, transport.ResponseWriter) error
unaryOutboundFunc func(context.Context, *transport.Request) (*transport.Response, error)
onewayHandlerFunc func(context.Context, *transport.Request) error
onewayOutboundFunc func(context.Context, *transport.Request) (transport.Ack, error)
streamHandlerFunc func(*transport.ServerStream) error
streamOutboundFunc func(context.Context, *transport.StreamRequest) (*transport.ClientStream, error)
)

// Params defines the parameters for creating the Middleware
type Params struct {
Tracer opentracing.Tracer
Transport string
}

// Middleware is the tracing middleware for all RPC types.
// It handles both observability and inter-process context propagation.
type Middleware struct {
tracer opentracing.Tracer
transport string
propagationFormat opentracing.BuiltinFormat

unaryInbound func(context.Context, *transport.Request, transport.ResponseWriter, unaryHandlerFunc) error
unaryOutbound func(context.Context, *transport.Request, unaryOutboundFunc) (*transport.Response, error)
onewayInbound func(context.Context, *transport.Request, onewayHandlerFunc) error
onewayOutbound func(context.Context, *transport.Request, onewayOutboundFunc) (transport.Ack, error)
streamInbound func(*transport.ServerStream, streamHandlerFunc) error
streamOutbound func(context.Context, *transport.StreamRequest, streamOutboundFunc) (*transport.ClientStream, error)
}

// New constructs a tracing middleware with the provided configuration.
func New(p Params) *Middleware {
m := &Middleware{
tracer: p.Tracer,
transport: p.Transport,
propagationFormat: transport.GetPropagationFormat(p.Transport),
}
if m.tracer == nil {
m.tracer = opentracing.GlobalTracer()
}

m.unaryInbound = m.handle
m.unaryOutbound = m.call
m.onewayInbound = m.handleOneway
m.onewayOutbound = m.callOneway
m.streamInbound = m.handleStream
m.streamOutbound = m.callStream

return m
}

// NewNop return a no-op tracing middleware
func NewNop() *Middleware {
return &Middleware{
tracer: opentracing.NoopTracer{},
unaryInbound: func(ctx context.Context, req *transport.Request, resw transport.ResponseWriter, handle unaryHandlerFunc) error {
return handle(ctx, req, resw)
},
unaryOutbound: func(ctx context.Context, req *transport.Request, call unaryOutboundFunc) (*transport.Response, error) {
return call(ctx, req)
},
onewayInbound: func(ctx context.Context, req *transport.Request, handleOneway onewayHandlerFunc) error {
return handleOneway(ctx, req)
},
onewayOutbound: func(ctx context.Context, req *transport.Request, callOneway onewayOutboundFunc) (transport.Ack, error) {
return callOneway(ctx, req)
},
streamInbound: func(srv *transport.ServerStream, handleStream streamHandlerFunc) error {
return handleStream(srv)
},
streamOutbound: func(ctx context.Context, req *transport.StreamRequest, callStream streamOutboundFunc) (*transport.ClientStream, error) {
return callStream(ctx, req)
},
}
}

func (m *Middleware) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error {
return m.unaryInbound(ctx, req, resw, h.Handle)
}

func (m *Middleware) Call(ctx context.Context, req *transport.Request, out transport.UnaryOutbound) (*transport.Response, error) {
return m.unaryOutbound(ctx, req, out.Call)
}

func (m *Middleware) InterceptUnaryOutbound(ctx context.Context, req *transport.Request, call unaryOutboundFunc) (*transport.Response, error) {
return m.unaryOutbound(ctx, req, call)
}

func (m *Middleware) HandleOneway(ctx context.Context, req *transport.Request, h transport.OnewayHandler) error {
return m.onewayInbound(ctx, req, h.HandleOneway)
}

func (m *Middleware) CallOneway(ctx context.Context, request *transport.Request, out transport.OnewayOutbound) (transport.Ack, error) {
return m.onewayOutbound(ctx, request, out.CallOneway)
}

func (m *Middleware) InterceptOnewayOutbound(ctx context.Context, request *transport.Request, callOneway onewayOutboundFunc) (transport.Ack, error) {
return m.onewayOutbound(ctx, request, callOneway)
}

func (m *Middleware) HandleStream(s *transport.ServerStream, h transport.StreamHandler) error {
return m.streamInbound(s, h.HandleStream)
}

func (m *Middleware) CallStream(ctx context.Context, req *transport.StreamRequest, out transport.StreamOutbound) (*transport.ClientStream, error) {
return m.streamOutbound(ctx, req, out.CallStream)
}

func (m *Middleware) InterceptStreamOutbound(ctx context.Context, req *transport.StreamRequest, callStream streamOutboundFunc) (*transport.ClientStream, error) {
return m.streamOutbound(ctx, req, callStream)
}

func (m *Middleware) handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter, handle unaryHandlerFunc) error {
parentSpanCtx, _ := m.tracer.Extract(m.propagationFormat, transport.GetPropagationCarrier(req.Headers.Items(), req.Transport))
extractOpenTracingSpan := &transport.ExtractOpenTracingSpan{
ParentSpanContext: parentSpanCtx,
Tracer: m.tracer,
TransportName: req.Transport,
StartTime: time.Now(),
// circular dependencies - we need to relocate the tracing tags
// ExtraTags: yarpc.OpentracingTags,
}
ctx, span := extractOpenTracingSpan.Do(ctx, req)
defer span.Finish()

err := handle(ctx, req, resw)
return updateSpanWithError(span, err)
}

func (m *Middleware) call(ctx context.Context, req *transport.Request, call unaryOutboundFunc) (*transport.Response, error) {
createOpenTracingSpan := &transport.CreateOpenTracingSpan{
Tracer: m.tracer,
TransportName: m.transport,
StartTime: time.Now(),
// circular dependencies - we need to relocate the tracing tags
//ExtraTags: yarpc.OpentracingTags
}
ctx, span := createOpenTracingSpan.Do(ctx, req)
defer span.Finish()

tracingHeaders := make(map[string]string)
if err := m.tracer.Inject(span.Context(), m.propagationFormat, transport.GetPropagationCarrier(tracingHeaders, m.transport)); err != nil {
ext.Error.Set(span, true)
span.LogFields(log.String("event", "error"), log.String("message", err.Error()))
return nil, err
}
for k, v := range tracingHeaders {
req.Headers = req.Headers.With(k, v)
}

res, err := call(ctx, req)
return res, updateSpanWithOutboundError(span, res, err)
}

func (m *Middleware) handleOneway(ctx context.Context, req *transport.Request, handleOneway onewayHandlerFunc) error {
// TODO implement me
panic("implement me")
}

func (m *Middleware) callOneway(ctx context.Context, request *transport.Request, callOneway onewayOutboundFunc) (transport.Ack, error) {
// TODO implement me
panic("implement me")
}

func (m *Middleware) handleStream(s *transport.ServerStream, handleStream streamHandlerFunc) error {
// TODO implement me
panic("implement me")
}

func (m *Middleware) callStream(ctx context.Context, req *transport.StreamRequest, callStream streamOutboundFunc) (*transport.ClientStream, error) {
// TODO implement me
// client stream is a bit more complex, as we need to intercept the clientStream.
// We can refer yarpc its own implementation: https://github.com/yarpc/yarpc-go/blob/dev/transport/grpc/stream.go#L103
// or opentracing contrib implementation: https://github.com/opentracing-contrib/go-grpc/blob/master/client.go#L131
panic("implement me")
}

func updateSpanWithError(span opentracing.Span, err error) error {
if err == nil {
return err
}

ext.Error.Set(span, true)
if yarpcerrors.IsStatus(err) {
status := yarpcerrors.FromError(err)
errCode := status.Code()
span.SetTag("rpc.yarpc.status_code", errCode.String())
span.SetTag("error.type", errCode.String())
return err
}

span.SetTag("error.type", "unknown_internal_yarpc")
return err
}

func updateSpanWithOutboundError(span opentracing.Span, res *transport.Response, err error) error {
isApplicationError := false
if res != nil {
isApplicationError = res.ApplicationError
}
if err == nil && !isApplicationError {
return err
}

ext.Error.Set(span, true)
if yarpcerrors.IsStatus(err) {
status := yarpcerrors.FromError(err)
errCode := status.Code()
span.SetTag("rpc.yarpc.status_code", errCode.String())
span.SetTag("error.type", errCode.String())
return err
}

if isApplicationError {
span.SetTag("error.type", "application_error")
return err
}

span.SetTag("error.type", "unknown_internal_yarpc")
return err
}
Loading

0 comments on commit 3a394bb

Please sign in to comment.