Skip to content

Commit

Permalink
[+] refacto tests + code improvment
Browse files Browse the repository at this point in the history
  • Loading branch information
Noooste committed Jan 4, 2024
1 parent c363a90 commit 006e320
Show file tree
Hide file tree
Showing 28 changed files with 319 additions and 359 deletions.
65 changes: 25 additions & 40 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@ import (
"context"
"crypto/x509"
"errors"
"fmt"
"github.com/Noooste/fhttp/http2"
tls "github.com/Noooste/utls"
"net"
"net/url"
"runtime/debug"
"strings"
"sync"
"time"
Expand All @@ -22,7 +20,7 @@ type Conn struct {

Conn net.Conn // Tcp connection

Pins *PinManager // pin manager
PinManager *PinManager // pin manager

TimeOut time.Duration
InsecureSkipVerify bool
Expand Down Expand Up @@ -51,6 +49,10 @@ func (c *Conn) SetContext(ctx context.Context) {
c.ctx = ctx
}

func (c *Conn) GetContext() context.Context {
return c.ctx
}

type ConnPool struct {
hosts map[string]*Conn
mu *sync.RWMutex
Expand Down Expand Up @@ -175,7 +177,9 @@ func (c *Conn) checkTLS() bool {
func (c *Conn) tryUpgradeHTTP2(tr *http2.Transport) bool {
if c.HTTP2 != nil && c.HTTP2.CanTakeNewRequest() {
return true
} else if c.TLS.ConnectionState().NegotiatedProtocol == http2.NextProtoTLS {
}

if c.TLS.ConnectionState().NegotiatedProtocol == http2.NextProtoTLS {
var err error
c.HTTP2, err = tr.NewClientConn(c.TLS)
return err == nil
Expand All @@ -197,7 +201,7 @@ func (c *Conn) Close() {
_ = c.HTTP2.Close()
c.HTTP2 = nil
}
c.Pins = nil
c.PinManager = nil
}

func (s *Session) initConn(req *Request) (conn *Conn, err error) {
Expand Down Expand Up @@ -226,10 +230,11 @@ func (s *Session) initConn(req *Request) (conn *Conn, err error) {
if conn.Conn == nil {
var dialContext func(ctx context.Context, network, addr string) (net.Conn, error)

if s.proxyDialer != nil {
s.proxyDialer.ForceHTTP2 = s.H2Proxy
s.proxyDialer.tr = s.tr2
dialContext = s.proxyDialer.DialContext
if s.ProxyDialer != nil {
s.ProxyDialer.ForceHTTP2 = s.H2Proxy
s.ProxyDialer.tr = s.HTTP2Transport
dialContext = s.ProxyDialer.DialContext
s.ProxyDialer.Dialer.Timeout = conn.TimeOut
} else {
dialContext = (&net.Dialer{Timeout: conn.TimeOut}).DialContext
}
Expand All @@ -255,7 +260,7 @@ func (s *Session) initConn(req *Request) (conn *Conn, err error) {

if req.parsedUrl.Scheme != SchemeWss {
// if tls connection is established, we can try to upgrade it to http2
conn.tryUpgradeHTTP2(s.tr2)
conn.tryUpgradeHTTP2(s.HTTP2Transport)
}

case SchemeHttp, SchemeWs:
Expand All @@ -269,37 +274,17 @@ func (s *Session) initConn(req *Request) (conn *Conn, err error) {
}

func (c *Conn) NewTLS(addr string) (err error) {
var done = make(chan bool, 1)
defer close(done)

go func() {
defer func() {
if err := recover(); err != nil {
done <- false
fmt.Println("panic:", err)
debug.PrintStack()
}
}()
do := false

var do bool
if c.Pins == nil && !c.InsecureSkipVerify {
c.Pins = NewPinManager()
do = true
}
if c.PinManager == nil && !c.InsecureSkipVerify {
c.PinManager = NewPinManager()
do = true
}

if !c.InsecureSkipVerify && (do || c.Pins.redo) {
if err = c.Pins.New(addr); err != nil {
done <- false
return
}
if !c.InsecureSkipVerify && (do || c.PinManager.redo) {
if err = c.PinManager.New(addr); err != nil {
return errors.New("pin verification failed")
}

//check if channel is closed
done <- true
}()

if !<-done {
return errors.New("pin verification failed")
}

var hostname = strings.Split(addr, ":")[0]
Expand All @@ -308,13 +293,13 @@ func (c *Conn) NewTLS(addr string) (err error) {
ServerName: hostname,
InsecureSkipVerify: c.InsecureSkipVerify,
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
if c.Pins == nil {
if c.PinManager == nil {
return nil
}

for _, chain := range verifiedChains {
for _, cert := range chain {
if c.Pins.Verify(cert) {
if c.PinManager.Verify(cert) {
return nil
}
}
Expand Down
19 changes: 9 additions & 10 deletions connection_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type proxyDialer struct {
DialTLS func(network string, address string) (net.Conn, string, error)

h2Mu sync.Mutex
h2Conn *http2.ClientConn
H2Conn *http2.ClientConn
conn net.Conn

tr *http2.Transport
Expand Down Expand Up @@ -71,7 +71,7 @@ func (s *Session) assignProxy(proxy string) error {
return fmt.Errorf(invalidProxy, proxy, "scheme "+parsed.Scheme+" is not supported")
}

s.proxyDialer = &proxyDialer{
s.ProxyDialer = &proxyDialer{
ProxyURL: parsed,
DefaultHeader: make(http.Header),
}
Expand All @@ -83,7 +83,7 @@ func (s *Session) assignProxy(proxy string) error {
} else {
auth := parsed.User.Username() + ":" + password
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(auth))
s.proxyDialer.DefaultHeader.Add("Proxy-Authorization", basicAuth)
s.ProxyDialer.DefaultHeader.Add("Proxy-Authorization", basicAuth)
}
}
}
Expand Down Expand Up @@ -153,7 +153,6 @@ func (c *proxyDialer) DialContext(ctx context.Context, network, address string)
if err != nil {
return nil, err
}

return dial.(proxy.ContextDialer).DialContext(ctx, network, address)
}

Expand All @@ -176,10 +175,10 @@ func (c *proxyDialer) DialContext(ctx context.Context, network, address string)

c.h2Mu.Lock()
unlocked := false
if c.h2Conn != nil && c.conn != nil {
if c.h2Conn.CanTakeNewRequest() {
if c.H2Conn != nil && c.conn != nil {
if c.H2Conn.CanTakeNewRequest() {
rc := c.conn
cc := c.h2Conn
cc := c.H2Conn
c.h2Mu.Unlock()
unlocked = true
proxyConn, err := c.connectHTTP2(req, rc, cc)
Expand All @@ -193,7 +192,7 @@ func (c *proxyDialer) DialContext(ctx context.Context, network, address string)
c.h2Mu.Unlock()
}

rawConn, negotiatedProtocol, err := c.initProxyConn(ctx, network)
rawConn, negotiatedProtocol, err := c.InitProxyConn(ctx, network)

if err != nil {
return nil, err
Expand All @@ -208,7 +207,7 @@ func (c *proxyDialer) DialContext(ctx context.Context, network, address string)
return proxyConn, nil
}

func (c *proxyDialer) initProxyConn(ctx context.Context, network string) (rawConn net.Conn, negotiatedProtocol string, err error) {
func (c *proxyDialer) InitProxyConn(ctx context.Context, network string) (rawConn net.Conn, negotiatedProtocol string, err error) {
switch c.ProxyURL.Scheme {
case SchemeHttp:
rawConn, err = c.Dialer.DialContext(ctx, network, c.ProxyURL.Host)
Expand Down Expand Up @@ -252,7 +251,7 @@ func (c *proxyDialer) connect(req *http.Request, conn net.Conn, negotiatedProtoc
if h2clientConn, err := c.tr.NewClientConn(conn); err == nil {
if proxyConn, err := c.connectHTTP2(req, conn, h2clientConn); err == nil {
c.h2Mu.Lock()
c.h2Conn = h2clientConn
c.H2Conn = h2clientConn
c.conn = conn
c.h2Mu.Unlock()
return proxyConn, err
Expand Down
23 changes: 18 additions & 5 deletions cookies.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,35 @@
package azuretls

import (
"bytes"
http "github.com/Noooste/fhttp"
"strings"
)

var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-")

func cookiesToString(cookies []*http.Cookie) string {
var f = make([]string, 0, len(cookies))
func CookiesToString(cookies []*http.Cookie) string {
var buf bytes.Buffer

var length = 0
for _, el := range cookies {
f = append(f, cookieNameSanitizer.Replace(el.Name)+"="+cookieNameSanitizer.Replace(el.Value))
length += len(el.Name) + len(el.Value) + 3
}
return strings.Join(f, "; ")

buf.Grow(length)
for _, el := range cookies {
buf.WriteString(cookieNameSanitizer.Replace(el.Name))
buf.WriteByte('=')
buf.WriteString(cookieNameSanitizer.Replace(el.Value))
buf.WriteString("; ")
}

buf.Truncate(buf.Len() - 2)
return buf.String()
}

func getCookiesMap(cookies []*http.Cookie) map[string]string {
var result = make(map[string]string)
var result = make(map[string]string, len(cookies))

for _, cookie := range cookies {
result[cookie.Name] = cookie.Value
Expand Down
12 changes: 6 additions & 6 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ func (ph PHeader) GetDefault() {

// Clone returns a copy of the header.
func (oh *OrderedHeaders) Clone() OrderedHeaders {
var clone = make(OrderedHeaders, len(*oh))
var clone = make(OrderedHeaders, 0, len(*oh))

for i, header := range *oh {
var fieldClone = make([]string, len(header))
for j, field := range header {
fieldClone[j] = field
for _, header := range *oh {
var fieldClone = make([]string, 0, len(header))
for _, field := range header {
fieldClone = append(fieldClone, field)
}
clone[i] = fieldClone
clone = append(clone, fieldClone)
}

return clone
Expand Down
14 changes: 7 additions & 7 deletions http2.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ const (
func (s *Session) ApplyHTTP2(fp string) error {
// check if HTTP2 is already initialized
// if not initialize it
if s.tr2 == nil {
if s.tr == nil {
if s.HTTP2Transport == nil {
if s.Transport == nil {
s.initHTTP1()
}

var err error
s.tr2, err = http2.ConfigureTransports(s.tr)
s.HTTP2Transport, err = http2.ConfigureTransports(s.Transport)
if err != nil {
return err
}
Expand All @@ -54,19 +54,19 @@ func (s *Session) ApplyHTTP2(fp string) error {
preHeader = split[3]
)

if err := applySettings(settings, s.tr2); err != nil {
if err := applySettings(settings, s.HTTP2Transport); err != nil {
return err
}

if err := applyWindowUpdate(windowUpdate, s.tr2); err != nil {
if err := applyWindowUpdate(windowUpdate, s.HTTP2Transport); err != nil {
return err
}

if err := applyPriorities(priorities, s.tr2); err != nil {
if err := applyPriorities(priorities, s.HTTP2Transport); err != nil {
return err
}

if err := applyPreHeader(preHeader, &s.PHeader, s.tr2); err != nil {
if err := applyPreHeader(preHeader, &s.PHeader, s.HTTP2Transport); err != nil {
return err
}

Expand Down
27 changes: 21 additions & 6 deletions pinner.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,21 @@ func (p *PinManager) New(addr string) (err error) {
return nil
}

func (p *PinManager) GetPins() []string {
p.mu.RLock()
defer p.mu.RUnlock()

var pins = make([]string, 0, len(p.m))
for k, v := range p.m {
if !v {
continue
}
pins = append(pins, k)
}

return pins
}

// AddPins associates a set of certificate pins with a given URL within
// a session. This allows for URL-specific pinning, useful in scenarios
// where different services (URLs) are trusted with different certificates.
Expand All @@ -99,12 +114,12 @@ func (s *Session) AddPins(u *url.URL, pins []string) error {
conn.mu.Lock()
defer conn.mu.Unlock()

if conn.Pins == nil {
conn.Pins = NewPinManager()
if conn.PinManager == nil {
conn.PinManager = NewPinManager()
}

for _, pin := range pins {
conn.Pins.m[pin] = true
conn.PinManager.m[pin] = true
}

return nil
Expand All @@ -119,11 +134,11 @@ func (s *Session) ClearPins(u *url.URL) error {
conn.mu.Lock()
defer conn.mu.Unlock()

for k := range conn.Pins.m {
conn.Pins.m[k] = false
for k := range conn.PinManager.m {
conn.PinManager.m[k] = false
}

conn.Pins.redo = true
conn.PinManager.redo = true

return nil
}
Loading

0 comments on commit 006e320

Please sign in to comment.