-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathconnection.go
136 lines (112 loc) · 2.96 KB
/
connection.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package azuretls
import (
"context"
"crypto/x509"
"errors"
tls "github.com/Noooste/utls"
"net"
"time"
)
// dialTLS opens a connection to addr via s.dial (directly, or through
// s.ProxyDialer when one is configured) and then performs the TLS handshake
// on it via s.upgradeTLS.
//
// If the TLS upgrade fails, the connection dialed here is closed before the
// error is returned so the underlying socket is not leaked.
func (s *Session) dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
	conn, err := s.dial(ctx, network, addr)
	if err != nil {
		return nil, errors.New("failed to dial: " + err.Error())
	}

	tlsConn, err := s.upgradeTLS(ctx, conn, addr)
	if err != nil {
		// upgradeTLS does not explicitly close conn on its error paths, and
		// this function owns the socket it just dialed. The Close error is
		// ignored: the upgrade error is the one worth reporting, and conn may
		// already be closed if the handshake was interrupted.
		_ = conn.Close()
		return nil, err
	}
	return tlsConn, nil
}
// dial establishes the raw (pre-TLS) connection to addr.
//
// When s.ProxyDialer is set, the connection is made through it. The user
// agent passed along is s.UserAgent, unless ctx carries a string under
// userAgentKey, in which case that string is used instead.
//
// Otherwise a net.Dialer is created with s.TimeOut as its Timeout and a 30s
// KeepAlive. If s.ModifyDialer is set, it may adjust that dialer first; an
// error from it aborts the dial and is returned unchanged.
func (s *Session) dial(ctx context.Context, network, addr string) (net.Conn, error) {
	if s.ProxyDialer != nil {
		userAgent := s.UserAgent
		// Comma-ok assertion: a non-string value stored under userAgentKey
		// falls back to s.UserAgent instead of panicking.
		if ua, ok := ctx.Value(userAgentKey).(string); ok {
			userAgent = ua
		}
		return s.ProxyDialer.DialContext(ctx, userAgent, network, addr)
	}

	dialer := &net.Dialer{
		Timeout:   s.TimeOut,
		KeepAlive: 30 * time.Second,
	}
	if s.ModifyDialer != nil {
		if err := s.ModifyDialer(dialer); err != nil {
			return nil, err
		}
	}
	return dialer.DialContext(ctx, network, addr)
}
// upgradeTLS runs a uTLS client handshake over the already-established conn
// and returns the resulting TLS connection.
//
// addr is expected to be "host:port". The host part is used as the SNI server
// name and for hostname verification. The full addr is the key used for
// certificate pinning (s.Pin and s.PinManager).
//
// Unless s.InsecureSkipVerify is set, verification works as follows:
//   - Before the handshake, s.Pin(addr) is called.
//   - During the handshake, every certificate in the verified chains must be
//     within its validity period.
//   - Every non-CA certificate must match the hostname.
//   - If a pin manager is registered for addr, at least one certificate in
//     the verified chains must satisfy it.
//
// The ClientHello comes from s.GetClientHelloSpec, or from the preset for
// s.Browser when that is nil. When ctx holds forceHTTP1Key=true, ALPN (both
// the spec's ALPN extension and Config.NextProtos) is restricted to
// "http/1.1".
//
// The handshake is bound to ctx, so cancellation or a deadline aborts it
// instead of letting a stalled peer block indefinitely. None of the error
// paths below explicitly close conn; the caller keeps ownership of it (an
// interrupted handshake may already have closed it).
func (s *Session) upgradeTLS(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) {
	hostname, _, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, errors.New("failed to split addr and port: " + err.Error())
	}

	if !s.InsecureSkipVerify {
		if err = s.Pin(addr); err != nil {
			return nil, errors.New("failed to pin: " + err.Error())
		}
	}

	forceHTTP1, _ := ctx.Value(forceHTTP1Key).(bool)

	// Prepare the ClientHello spec and the Config completely before handing
	// the Config to tls.UClient: a TLS Config must not be modified once it
	// has been passed to a TLS function.
	getSpec := s.GetClientHelloSpec
	if getSpec == nil {
		getSpec = GetBrowserClientHelloFunc(s.Browser)
	}
	specs := getSpec()
	if forceHTTP1 {
		for _, ext := range specs.Extensions {
			if alpn, ok := ext.(*tls.ALPNExtension); ok {
				alpn.AlpnProtocols = []string{"http/1.1"}
			}
		}
	}

	config := &tls.Config{
		ServerName:         hostname,
		InsecureSkipVerify: s.InsecureSkipVerify,
		VerifyPeerCertificate: func(_ [][]byte, verifiedChains [][]*x509.Certificate) error {
			if s.InsecureSkipVerify {
				return nil
			}

			now := time.Now()
			for _, chain := range verifiedChains {
				for _, cert := range chain {
					if now.Before(cert.NotBefore) {
						return errors.New("certificate is not valid yet")
					}
					if now.After(cert.NotAfter) {
						return errors.New("certificate is expired")
					}
					if cert.IsCA {
						continue
					}
					// Local err: this callback runs during the handshake and
					// must not overwrite the enclosing function's err.
					if err := cert.VerifyHostname(hostname); err != nil {
						return err
					}
				}
			}

			// Indexing a nil map yields the zero value, so this single locked
			// lookup covers both "no PinManager" and "no pin for addr" without
			// reading s.PinManager outside pinMu.
			s.pinMu.RLock()
			manager := s.PinManager[addr]
			s.pinMu.RUnlock()
			if manager == nil {
				return nil
			}

			for _, chain := range verifiedChains {
				for _, cert := range chain {
					if manager.Verify(cert) {
						return nil
					}
				}
			}
			return errors.New("pin verification failed")
		},
	}
	if forceHTTP1 {
		config.NextProtos = []string{"http/1.1"}
	}

	tlsConn := tls.UClient(conn, config, tls.HelloCustom)
	if err = tlsConn.ApplyPreset(specs); err != nil {
		return nil, errors.New("failed to apply preset: " + err.Error())
	}
	if err = tlsConn.HandshakeContext(ctx); err != nil {
		return nil, errors.New("failed to handshake: " + err.Error())
	}
	return tlsConn.Conn, nil
}