Skip to content

Commit

Permalink
Merge pull request #2 from swithek/property-check
Browse files Browse the repository at this point in the history
Session properties validation and optional automatic termination
  • Loading branch information
swithek authored Oct 16, 2019
2 parents b729ceb + c22e258 commit 8218667
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 8 deletions.
25 changes: 21 additions & 4 deletions manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type Manager struct {
expiresIn time.Duration
withIP bool
withAgent bool
validate bool

genID func() string
reject func(error) http.Handler
Expand Down Expand Up @@ -121,6 +122,15 @@ func WithAgent(w bool) setter {
}
}

// Validate sets whether IP and User-Agent data
// should be checked on each request to authenticated
// routes.
func Validate(v bool) setter {
return func(m *Manager) {
m.validate = v
}
}

// GenID sets the function which will be called when a new session
// is created and ID is being generated.
// Defaults to DefaultGenID function.
Expand Down Expand Up @@ -214,8 +224,9 @@ func (m *Manager) Init(w http.ResponseWriter, r *http.Request, key string) error
// Public wraps the provided handler, checks whether the session, associated to
// the ID stored in request's cookie, exists in the store and adds it to the
// request's context.
// If no valid cookie is provided, session doesn't exist or the store returns
// an error, wrapped handler will be activated nonetheless.
// If no valid cookie is provided, session doesn't exist, the properties of the
// request don't match the ones associated to the session (if validation is
// activated) or the store returns an error, wrapped handler will be activated nonetheless.
// Rejection function will be called only for non-http side effects (like error logging),
// but response/request control will not be passed to it.
func (m *Manager) Public(next http.Handler) http.Handler {
Expand All @@ -228,8 +239,9 @@ func (m *Manager) Public(next http.Handler) http.Handler {
// Auth wraps the provided handler, checks whether the session, associated to
// the ID stored in request's cookie, exists in the store and adds it to the
// request's context.
// Wrapped handler will be activated only if there are no errors returned from the store
// and the session is found, otherwise, the manager's rejection function will be called.
// Wrapped handler will be activated only if there are no errors returned from the store,
// the session is found and its properties match the ones in the request (if
// validation is activated), otherwise, the manager's rejection function will be called.
func (m *Manager) Auth(next http.Handler) http.Handler {
return m.wrap(m.reject, next)
}
Expand Down Expand Up @@ -257,6 +269,11 @@ func (m *Manager) wrap(rej func(error) http.Handler, next http.Handler) http.Han
return
}

if m.validate && !s.isValid(r) {
rej(errors.New("unauthorized")).ServeHTTP(w, r)
return
}

next.ServeHTTP(w, r.WithContext(NewContext(ctx, s)))
})
}
Expand Down
73 changes: 69 additions & 4 deletions manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"

"xojoc.pw/useragent"
)

func TestCookieName(t *testing.T) {
Expand Down Expand Up @@ -93,6 +96,15 @@ func TestWithAgent(t *testing.T) {
}
}

func TestValidate(t *testing.T) {
m := Manager{}
val := true
Validate(val)(&m)
if m.validate != val {
t.Errorf("want %t, got %t", val, m.validate)
}
}

func TestGenID(t *testing.T) {
m := Manager{}
val := func() string { return "" }
Expand Down Expand Up @@ -307,6 +319,8 @@ func TestInit(t *testing.T) {
}

func TestPublic(t *testing.T) {
ip := "127.0.0.1"

type check func(*testing.T, *StoreMock, *httptest.ResponseRecorder)

checks := func(cc ...check) []check { return cc }
Expand Down Expand Up @@ -339,7 +353,12 @@ func TestPublic(t *testing.T) {
storeStub := func(bRes bool, err error) *StoreMock {
return &StoreMock{
FetchByIDFunc: func(_ context.Context, _ string) (Session, bool, error) {
return Session{}, bRes, err
s := Session{
IP: net.ParseIP(ip),
}
s.Agent.OS = useragent.OSLinux
s.Agent.Browser = "Firefox"
return s, bRes, err
},
}
}
Expand All @@ -350,6 +369,7 @@ func TestPublic(t *testing.T) {
Store *StoreMock
Cookie *http.Cookie
Auth bool
IP string
Checks []check
}{
"Invalid cookie": {
Expand All @@ -359,6 +379,7 @@ func TestPublic(t *testing.T) {
Value: id,
},
Auth: false,
IP: ip,
Checks: checks(
hasResp(http.StatusOK),
wasFetchByIDCalled(0, ""),
Expand All @@ -371,6 +392,7 @@ func TestPublic(t *testing.T) {
Value: id,
},
Auth: false,
IP: ip,
Checks: checks(
hasResp(http.StatusOK),
wasFetchByIDCalled(1, id),
Expand All @@ -383,6 +405,20 @@ func TestPublic(t *testing.T) {
Value: id,
},
Auth: false,
IP: ip,
Checks: checks(
hasResp(http.StatusOK),
wasFetchByIDCalled(1, id),
),
},
"IP is invalid": {
Store: storeStub(true, nil),
Cookie: &http.Cookie{
Name: defaultName,
Value: id,
},
Auth: false,
IP: "127.0.0.2",
Checks: checks(
hasResp(http.StatusOK),
wasFetchByIDCalled(1, id),
Expand All @@ -395,6 +431,7 @@ func TestPublic(t *testing.T) {
Value: id,
},
Auth: true,
IP: ip,
Checks: checks(
hasResp(http.StatusOK),
wasFetchByIDCalled(1, id),
Expand All @@ -418,7 +455,9 @@ func TestPublic(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "http://example.com", nil)
req.AddCookie(c.Cookie)
m := Manager{store: c.Store}
req.Header.Set("X-Forwarded-For", c.IP)
req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux i686; rv:38.0) Gecko/20100101 Firefox/38.0")
m := Manager{store: c.Store, validate: true}
m.Defaults()
m.Public(next(t, c.Auth)).ServeHTTP(rec, req)
for _, ch := range c.Checks {
Expand All @@ -429,6 +468,8 @@ func TestPublic(t *testing.T) {
}

func TestAuth(t *testing.T) {
ip := "127.0.0.1"

type check func(*testing.T, *StoreMock, *httptest.ResponseRecorder)

checks := func(cc ...check) []check { return cc }
Expand Down Expand Up @@ -461,7 +502,12 @@ func TestAuth(t *testing.T) {
storeStub := func(bRes bool, err error) *StoreMock {
return &StoreMock{
FetchByIDFunc: func(_ context.Context, _ string) (Session, bool, error) {
return Session{}, bRes, err
s := Session{
IP: net.ParseIP(ip),
}
s.Agent.OS = useragent.OSLinux
s.Agent.Browser = "Firefox"
return s, bRes, err
},
}
}
Expand All @@ -471,6 +517,7 @@ func TestAuth(t *testing.T) {
cc := map[string]struct {
Store *StoreMock
Cookie *http.Cookie
IP string
Checks []check
}{
"Invalid cookie": {
Expand All @@ -479,6 +526,7 @@ func TestAuth(t *testing.T) {
Name: "incorrect",
Value: id,
},
IP: ip,
Checks: checks(
hasResp(http.StatusUnauthorized, true),
wasFetchByIDCalled(0, ""),
Expand All @@ -490,6 +538,7 @@ func TestAuth(t *testing.T) {
Name: defaultName,
Value: id,
},
IP: ip,
Checks: checks(
hasResp(http.StatusUnauthorized, true),
wasFetchByIDCalled(1, id),
Expand All @@ -501,6 +550,19 @@ func TestAuth(t *testing.T) {
Name: defaultName,
Value: id,
},
IP: ip,
Checks: checks(
hasResp(http.StatusUnauthorized, true),
wasFetchByIDCalled(1, id),
),
},
"IP is invalid": {
Store: storeStub(true, nil),
Cookie: &http.Cookie{
Name: defaultName,
Value: id,
},
IP: "127.0.0.2",
Checks: checks(
hasResp(http.StatusUnauthorized, true),
wasFetchByIDCalled(1, id),
Expand All @@ -512,6 +574,7 @@ func TestAuth(t *testing.T) {
Name: defaultName,
Value: id,
},
IP: ip,
Checks: checks(
hasResp(http.StatusOK, false),
wasFetchByIDCalled(1, id),
Expand All @@ -535,7 +598,9 @@ func TestAuth(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "http://example.com", nil)
req.AddCookie(c.Cookie)
m := Manager{store: c.Store}
req.Header.Set("X-Forwarded-For", c.IP)
req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux i686; rv:38.0) Gecko/20100101 Firefox/38.0")
m := Manager{store: c.Store, validate: true}
m.Defaults()
m.Auth(next(t)).ServeHTTP(rec, req)
for _, ch := range c.Checks {
Expand Down
23 changes: 23 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,29 @@ type Session struct {
} `json:"agent"`
}

// isValid checks whether the incoming request's properties match
// active session's properties.
func (s Session) isValid(r *http.Request) bool {
ip := true
if len(s.IP) != 0 {
ip = s.IP.Equal(readIP(r))
}

a := useragent.Parse(r.Header.Get("User-Agent"))

os := true
if s.Agent.OS != "" {
os = s.Agent.OS == a.OS
}

browser := true
if s.Agent.Browser != "" {
browser = s.Agent.Browser == a.Name
}

return ip && os && browser
}

// newSession creates a new Session with the data extracted from
// the provided request, user key and a freshly generated ID.
func (m *Manager) newSession(r *http.Request, key string) Session {
Expand Down
93 changes: 93 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,99 @@ import (
"xojoc.pw/useragent"
)

func TestIsValid(t *testing.T) {
ses := Session{
IP: net.ParseIP("127.0.0.1"),
}

ses.Agent.OS = useragent.OSWindows
ses.Agent.Browser = "Chrome"

req := httptest.NewRequest("GET", "http://example.com/", nil)
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.111 Safari/537.36")
req.RemoteAddr = "127.0.0.1:3000"

cc := map[string]struct {
Req *http.Request
Session Session
Res bool
}{
"Invalid IP": {
Req: func() *http.Request {
creq := httptest.NewRequest("GET", "http://example.com/", nil)
creq.Header.Set("User-Agent", req.Header.Get("User-Agent"))
creq.RemoteAddr = "127.0.0.2:3000"
return creq
}(),
Session: ses,
Res: false,
},
"Invalid User-Agent browser": {
Req: func() *http.Request {
creq := httptest.NewRequest("GET", "http://example.com/", nil)
creq.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/42.0.2311.135 Safari/537.36 Edge/12.246")
creq.RemoteAddr = req.RemoteAddr
return creq
}(),
Session: ses,
Res: false,
},
"Invalid User-Agent os": {
Req: func() *http.Request {
creq := httptest.NewRequest("GET", "http://example.com/", nil)
creq.Header.Set("User-Agent", "Mozilla/5.0 (Linux; Android 5.1.1; SM-G928X Build/LMY47X) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.83 Mobile Safari/537.36")
creq.RemoteAddr = req.RemoteAddr
return creq
}(),
Session: ses,
Res: false,
},
"Successful all fields except ip validation": {
Req: req,
Session: func() Session {
cses := ses
cses.IP = nil
return cses
}(),
Res: true,
},
"Successful all fields except os validation": {
Req: req,
Session: func() Session {
cses := ses
cses.Agent.OS = ""
return cses
}(),
Res: true,
},
"Successful all fields except browser validation": {
Req: req,
Session: func() Session {
cses := ses
cses.Agent.Browser = ""
return cses
}(),
Res: true,
},
"Successful all fields validation": {
Req: req,
Session: ses,
Res: true,
},
}

for cn, c := range cc {
c := c
t.Run(cn, func(t *testing.T) {
t.Parallel()
res := c.Session.isValid(c.Req)
if res != c.Res {
t.Errorf("want %t, got %t", c.Res, res)
}
})
}
}

func TestNewSession(t *testing.T) {
m := Manager{
expiresIn: time.Hour,
Expand Down

0 comments on commit 8218667

Please sign in to comment.