-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsession_handler.go
201 lines (174 loc) · 6.45 KB
/
session_handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
package mgohttp
import (
"context"
"fmt"
"net/http"
"runtime"
"strings"
"sync"
"time"
"github.com/Clever/mgohttp/internal"
opentracing "github.com/opentracing/opentracing-go"
ext "github.com/opentracing/opentracing-go/ext"
"gopkg.in/Clever/kayvee-go.v6/logger"
mgo "gopkg.in/mgo.v2"
)
// SessionHandlerConfig dictates how we inject mongo sessions into the context
// of the HTTP request. It is the sole argument to NewSessionHandler, which
// copies each field into the SessionHandler it returns.
type SessionHandlerConfig struct {
// Sess is the parent session. The handler calls Copy on it to obtain a
// per-request session the first time a request asks for one.
Sess *mgo.Session
// Database is the database name under which the session getter is stored in
// the request context; it is also recorded as the peer service on the
// tracing span.
Database string
// Timeout bounds how long the wrapped handler may run before the middleware
// answers with an error status, and is also applied as the socket timeout
// of every copied session.
Timeout time.Duration
// Handler is the wrapped HTTP handler that receives the amended request.
Handler http.Handler
}
// mgoSessionCopier is the only capability SessionHandler requires from its
// parent session: producing a copy of it.
// NOTE(review): presumably an interface so tests can substitute a fake parent
// session — confirm against the test files.
type mgoSessionCopier interface {
// Copy returns a new *mgo.Session derived from the parent session.
Copy() *mgo.Session
}
// SessionHandler is an HTTP middleware that injects a new copied mongo session
// into the Context of the request.
// This middleware handles timing out inflight Mongo requests.
// The session is created lazily: nothing is copied unless the wrapped handler
// asks for a session.
type SessionHandler struct {
// parentSession is copied to create each request's session.
parentSession mgoSessionCopier
// database is the name the session getter is registered under in the
// request context.
database string
// timeout is both the per-request deadline enforced by ServeHTTP and the
// socket timeout set on each copied session.
timeout time.Duration
// handler is the wrapped HTTP handler.
handler http.Handler
// errorCode is the HTTP status written to the client when the timeout fires.
errorCode int // this is defaulted to 503, only the tests can override
}
// NewSessionHandler builds a *SessionHandler from cfg and returns it as an
// http.Handler. Every field of cfg is copied over as-is; the status code
// written when a request times out is fixed to
// http.StatusServiceUnavailable (503).
func NewSessionHandler(cfg SessionHandlerConfig) http.Handler {
	h := new(SessionHandler)
	h.parentSession = cfg.Sess
	h.database = cfg.Database
	h.timeout = cfg.Timeout
	h.handler = cfg.Handler
	h.errorCode = http.StatusServiceUnavailable
	return h
}
// getCallerName retrieves the name of the calling function: the first frame
// on the current call stack whose function name contains neither "mgohttp"
// nor "runtime". The result is used as a tracing span operation name.
//
// Only the innermost 10 program counters are inspected. It returns
// "mgohttp-default-fn" when no inspected frame qualifies, and "" when the
// runtime reports no program counters at all.
//
// rough source: https://golang.org/pkg/runtime/#example_Frames
func getCallerName() string {
	// Ask runtime.Callers for up to 10 pcs, including runtime.Callers itself.
	var pcs [10]uintptr
	n := runtime.Callers(0, pcs[:])
	if n == 0 {
		// No pcs available. Stop now.
		// This can happen if the first argument to runtime.Callers is large.
		return ""
	}

	// Pass only valid pcs to runtime.CallersFrames.
	// A fixed number of pcs can expand to an indefinite number of Frames.
	frames := runtime.CallersFrames(pcs[:n])
	for {
		frame, more := frames.Next()
		name := frame.Function
		filtered := strings.Contains(name, "mgohttp") || strings.Contains(name, "runtime")

		// Inspect the frame BEFORE looking at `more`: `more` only says whether
		// another frame follows, so the last valid frame arrives with
		// more == false and must still be considered. Frames without a known
		// function name are skipped so we never produce an empty span name.
		if name != "" && !filtered {
			return name
		}
		if !more {
			return "mgohttp-default-fn"
		}
	}
}
// ServeHTTP injects a "getter" into the HTTP request context that allows any
// wrapped HTTP handler to retrieve a database session, then runs the wrapped
// handler in its own goroutine while enforcing c.timeout.
//
// The session is created lazily on the first call to the getter; repeated
// calls within the same request return the same session. If the wrapped
// handler finishes before the timeout, its buffered response is copied to w.
// Otherwise c.errorCode is written to w and the session (if any) is closed
// when ServeHTTP returns, even though the wrapped handler may still be
// running.
func (c *SessionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Instantiate the nil session and span objects that may be lazily
	// instantiated if the request handler asks for a session. All three are
	// guarded by sessionMutex because, after a timeout, the handler goroutine
	// and the deferred cleanup below can touch them concurrently.
	var (
		newSession   *mgo.Session
		libSpan, sp  opentracing.Span
		sessionMutex sync.Mutex
	)

	sessionTimer := time.NewTimer(c.timeout)
	// Release the timer as soon as the request is over instead of leaving it
	// pending until it fires on its own.
	defer sessionTimer.Stop()

	ctx := r.Context()

	// At the end, if we instantiated a session (and inherently a tracing span), close/finish
	// them to clean up.
	defer func() {
		sessionMutex.Lock()
		defer sessionMutex.Unlock()
		if newSession != nil {
			newSession.Close()
			// if we didn't open a session, we don't care about closing the spans
			sp.Finish()
			libSpan.Finish()
		}
	}()

	// Create a timeoutWriter to avoid races on the http.ResponseWriter.
	tw := &timeoutWriter{
		w: w,
		h: make(http.Header),
	}

	// getSession is injected into the Context, repeated calls by the same request will return
	// the same session.
	var getSession internal.SessionGetter = func(ctx context.Context) (*mgo.Session, context.Context) {
		// Hold the mutex for the whole getter so that the reads of newSession
		// and the reads/writes of sp and libSpan cannot race with the deferred
		// cleanup or with another goroutine of the same request.
		sessionMutex.Lock()
		defer sessionMutex.Unlock()

		// we've already created a session for this request, shortcircuit and return that session.
		if newSession != nil {
			// close the prior span & open a new one
			sp.Finish()
			sp, ctx = opentracing.StartSpanFromContext(ctx, getCallerName())
			return newSession, ctx
		}

		libSpan, ctx = opentracing.StartSpanFromContext(ctx, "mgohttp")
		// set the service as the database - this will convey that it is a dependency of the service
		ext.PeerService.Set(libSpan, c.database)
		ext.SpanKind.Set(libSpan, ext.SpanKindRPCClientEnum)
		ext.Component.Set(libSpan, "mgohttp")
		ext.DBType.Set(libSpan, "mongodb")

		sp, ctx = opentracing.StartSpanFromContext(ctx, getCallerName())

		// Create a session copy. We prefer Copy over Clone because opening new sockets
		// allows for greater throughput to the database.
		// Sessions created using Clone queue all requests through the parent connection's
		// socket. This creates a slow bottleneck when expensive queries appear.
		// NOTE: consider allowing the consumer to pass in a "newSession" function of
		// `func() *mgo.Session` if we are pressed for more flexibility here.
		// NOTE(review): a first call arriving after ServeHTTP has already
		// returned on timeout creates a session nobody closes — confirm
		// whether callers can reach that path.
		newSession = c.parentSession.Copy()

		// SetSocketTimeout guarantees that no individual query to mongo can take longer than
		// the RequestTimeoutDuration value.
		newSession.SetSocketTimeout(c.timeout)

		return newSession, ctx
	}

	done := make(chan struct{}) // done signifies the end of the HTTP request when closed
	go func() {
		defer func() {
			// If the SessionHandler timeout is hit, we close the mgo session. But server handler
			// code may continue executing (even if the server timeout is the same as the
			// SessionHandler timeout). If another DB operation is attempted, mgo will panic with a
			// "Session already closed" error. Let's catch these panics to prevent server crashes.
			if err := recover(); err != nil {
				if err != "Session already closed" {
					panic(err)
				}
				logger.FromContext(r.Context()).Error("mgo-session-already-closed-panic-caught")
			}
		}()

		// amend the request context with the database connection then serve the wrapped
		// HTTP handler
		newCtx := internal.NewContext(ctx, c.database, getSession)
		c.handler.ServeHTTP(tw, r.WithContext(newCtx))
		close(done)
	}()

	// this select guarantees that we only write to the ResponseWriter a single time
	select {
	case <-done:
		// If we served the request without being preempted by the timer, copy over all the
		// writes from the timeout handler to the actual http.ResponseWriter.
		tw.copyToResponseWriter(w)
	case <-sessionTimer.C:
		tw.setTimedOut()
		w.WriteHeader(c.errorCode)
		logger.FromContext(r.Context()).Error("mongo-session-killed")
	}
}
// FromContext looks up the session getter registered for database in ctx,
// invokes it, and returns the resulting session wrapped together with the
// context the getter handed back.
//
// It panics when ctx carries no session getter for database.
func FromContext(ctx context.Context, database string) MongoSession {
	getter, found := ctx.Value(internal.GetMgoSessionKey(database)).(internal.SessionGetter)
	if !found {
		panic(fmt.Sprintf("SessionFromContext must receive a valid database name: %s not found", database))
	}

	// The getter returns a derived context; keep it alongside the session
	// rather than the one we were called with.
	session, sessionCtx := getter(ctx)
	return tracedMgoSession{
		sess: session,
		ctx:  sessionCtx,
	}
}