forked from denisenkom/go-mssqldb
-
Notifications
You must be signed in to change notification settings - Fork 68
/
Copy pathsession.go
100 lines (87 loc) · 3.1 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package mssql
import (
"context"
"fmt"
"github.com/google/uuid"
"github.com/microsoft/go-mssqldb/aecmk"
"github.com/microsoft/go-mssqldb/msdsn"
)
// newSession builds a tdsSession that writes through outbuf and logs via logger.
//
// The session's log flags are taken from p.LogFlags, and its Always Encrypted
// settings are seeded with the globally registered column encryption key
// providers. The activity id is scanned from p.ActivityID and a fresh random
// connection id is generated. Both steps are best effort: any error from Scan
// or from uuid.NewRandom is ignored rather than failing session creation.
func newSession(outbuf *tdsBuffer, logger ContextLogger, p msdsn.Config) *tdsSession {
	aeSettings := &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()}
	session := &tdsSession{
		buf:        outbuf,
		logger:     logger,
		logFlags:   uint64(p.LogFlags),
		aeSettings: aeSettings,
	}

	// Best effort: the Scan error is deliberately discarded.
	_ = session.activityid.Scan(p.ActivityID)

	// Generating a GUID can fail with a small probability; only assign a
	// connection id when generation succeeded.
	if id, err := uuid.NewRandom(); err == nil {
		_ = session.connid.Scan(id[:])
	}
	return session
}
// preparePreloginFields assembles the option map for the PRELOGIN packet,
// keyed by prelogin option token.
//
// The map always contains the driver version, the encryption mode derived from
// p.Encryption, the zero-terminated instance name, a zero thread id, and a MARS
// flag of 0. Unless p.NoTraceID is set, it also contains a 36-byte TRACEID built
// from the session's connection id and activity id, with the trailing 4 bytes
// left zero. This step is logged at debug level when enabled. When
// fe.FedAuthLibrary is not FedAuthLibraryReserved, FEDAUTHREQUIRED is set to 1.
//
// It panics if p.Encryption is not one of the handled msdsn encryption values.
func (s *tdsSession) preparePreloginFields(ctx context.Context, p msdsn.Config, fe *featureExtFedAuth) map[uint8][]byte {
	// Instance names are sent as a zero-terminated byte string.
	instance := append([]byte(p.Instance), 0)

	var encrypt byte
	switch p.Encryption {
	case msdsn.EncryptionDisabled:
		encrypt = encryptNotSup
	case msdsn.EncryptionRequired:
		encrypt = encryptOn
	case msdsn.EncryptionOff:
		encrypt = encryptOff
	case msdsn.EncryptionStrict:
		encrypt = encryptStrict
	default:
		panic(fmt.Errorf("Unsupported Encryption Config %v", p.Encryption))
	}

	version := getDriverVersion(driverVersion)
	fields := map[uint8][]byte{
		// 4-byte version, least significant byte first, then a 2-byte minor version sent as zero.
		preloginVERSION:    {byte(version), byte(version >> 8), byte(version >> 16), byte(version >> 24), 0, 0},
		preloginENCRYPTION: {encrypt},
		preloginINSTOPT:    instance,
		preloginTHREADID:   {0, 0, 0, 0},
		preloginMARS:       {0}, // MARS disabled
	}

	if !p.NoTraceID {
		// Layout: 16-byte connection id, 16-byte activity id, then a 4-byte
		// sequence number that is left zero by make.
		traceID := make([]byte, 36)
		connID, _ := s.connid.Value()
		activityID, _ := s.activityid.Value()
		copy(traceID[0:16], connID.([]byte))
		copy(traceID[16:32], activityID.([]byte))
		fields[preloginTRACEID] = traceID

		if s.logFlags&logDebug != 0 {
			s.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("Creating prelogin packet with connection id '%s' and activity id '%s'", s.connid, s.activityid))
		}
	}

	if fe.FedAuthLibrary != FedAuthLibraryReserved {
		fields[preloginFEDAUTHREQUIRED] = []byte{1}
	}
	return fields
}
// logFunc lazily produces a log message. tdsSession.Log only invokes it when
// the requested category is enabled, so message construction is skipped for
// disabled categories.
type logFunc func() string
// logPrefix returns the "aid:<activity id> cid:<connection id> - " prefix that
// Log, LogS and LogF prepend to messages when msdsn.LogSessionIDs is enabled in
// the session's log flags. Otherwise it returns the empty string.
func (s *tdsSession) logPrefix() string {
	if s.logFlags&uint64(msdsn.LogSessionIDs) == 0 {
		return ""
	}
	return fmt.Sprintf("aid:%v cid:%v - ", s.activityid, s.connid)
}
// LogS emits the fixed message msg under category when the session's log flags
// include that category. The message is prefixed with logPrefix, which is
// non-empty only when session ids are being logged.
func (s *tdsSession) LogS(ctx context.Context, category msdsn.Log, msg string) {
	if s.logFlags&uint64(category) == 0 {
		return
	}
	s.logger.Log(ctx, category, s.logPrefix()+msg)
}
// Log emits the message produced by logFunc under category, but only when the
// session's log flags include that category. For disabled categories logFunc
// is never called, so callers pay for building the message only when the trace
// is actually enabled. The message is prefixed with logPrefix.
func (s *tdsSession) Log(ctx context.Context, category msdsn.Log, logFunc logFunc) {
	enabled := s.logFlags&uint64(category) != 0
	if !enabled {
		return
	}
	msg := s.logPrefix() + logFunc()
	s.logger.Log(ctx, category, msg)
}
// LogF formats a message with fmt.Sprintf(format, a...) and emits it under
// category, but only when the session's log flags include that category.
// Formatting is skipped entirely for disabled categories. The message is
// prefixed with logPrefix.
func (s *tdsSession) LogF(ctx context.Context, category msdsn.Log, format string, a ...any) {
	if s.logFlags&uint64(category) == 0 {
		return
	}
	s.logger.Log(ctx, category, s.logPrefix()+fmt.Sprintf(format, a...))
}