diff --git a/manager.go b/manager.go index bef0a44..54c2312 100644 --- a/manager.go +++ b/manager.go @@ -31,6 +31,7 @@ type Manager struct { expiresIn time.Duration withIP bool withAgent bool + validate bool genID func() string reject func(error) http.Handler @@ -121,6 +122,15 @@ func WithAgent(w bool) setter { } } +// Validate sets whether IP and User-Agent data +// should be checked on each request to authenticated +// routes. +func Validate(v bool) setter { + return func(m *Manager) { + m.validate = v + } +} + // GenID sets the function which will be called when a new session // is created and ID is being generated. // Defaults to DefaultGenID function. @@ -214,8 +224,9 @@ func (m *Manager) Init(w http.ResponseWriter, r *http.Request, key string) error // Public wraps the provided handler, checks whether the session, associated to // the ID stored in request's cookie, exists in the store and adds it to the // request's context. -// If no valid cookie is provided, session doesn't exist or the store returns -// an error, wrapped handler will be activated nonetheless. +// If no valid cookie is provided, session doesn't exist, the properties of the +// request don't match the ones associated to the session (if validation is +// activated) or the store returns an error, wrapped handler will be activated nonetheless. // Rejection function will be called only for non-http side effects (like error logging), // but response/request control will not be passed to it. func (m *Manager) Public(next http.Handler) http.Handler { @@ -228,8 +239,9 @@ func (m *Manager) Public(next http.Handler) http.Handler { // Auth wraps the provided handler, checks whether the session, associated to // the ID stored in request's cookie, exists in the store and adds it to the // request's context. -// Wrapped handler will be activated only if there are no errors returned from the store -// and the session is found, otherwise, the manager's rejection function will be called. +// Wrapped handler will be activated only if there are no errors returned from the store, +// the session is found and its properties match the ones in the request (if +// validation is activated), otherwise, the manager's rejection function will be called. func (m *Manager) Auth(next http.Handler) http.Handler { return m.wrap(m.reject, next) } @@ -257,6 +269,11 @@ func (m *Manager) wrap(rej func(error) http.Handler, next http.Handler) http.Han return } + if m.validate && !s.isValid(r) { + rej(errors.New("unauthorized")).ServeHTTP(w, r) + return + } + next.ServeHTTP(w, r.WithContext(NewContext(ctx, s))) }) } diff --git a/manager_test.go b/manager_test.go index b03fcde..15092f1 100644 --- a/manager_test.go +++ b/manager_test.go @@ -5,11 +5,14 @@ import ( "encoding/json" "errors" "fmt" + "net" "net/http" "net/http/httptest" "reflect" "testing" "time" + + "xojoc.pw/useragent" ) func TestCookieName(t *testing.T) { @@ -93,6 +96,15 @@ func TestWithAgent(t *testing.T) { } } +func TestValidate(t *testing.T) { + m := Manager{} + val := true + Validate(val)(&m) + if m.validate != val { + t.Errorf("want %t, got %t", val, m.validate) + } +} + func TestGenID(t *testing.T) { m := Manager{} val := func() string { return "" } @@ -307,6 +319,8 @@ func TestInit(t *testing.T) { } func TestPublic(t *testing.T) { + ip := "127.0.0.1" + type check func(*testing.T, *StoreMock, *httptest.ResponseRecorder) checks := func(cc ...check) []check { return cc } @@ -339,7 +353,12 @@ func TestPublic(t *testing.T) { storeStub := func(bRes bool, err error) *StoreMock { return &StoreMock{ FetchByIDFunc: func(_ context.Context, _ string) (Session, bool, error) { - return Session{}, bRes, err + s := Session{ + IP: net.ParseIP(ip), + } + s.Agent.OS = useragent.OSLinux + s.Agent.Browser = "Firefox" + return s, bRes, err }, } } @@ -350,6 +369,7 @@ func TestPublic(t *testing.T) { Store *StoreMock Cookie *http.Cookie Auth bool + IP string Checks []check }{ "Invalid cookie": { @@ -359,6 +379,7 @@ func TestPublic(t *testing.T) { Value: id, }, Auth: false, + IP: ip, Checks: checks( hasResp(http.StatusOK), wasFetchByIDCalled(0, ""), @@ -371,6 +392,7 @@ func TestPublic(t *testing.T) { Value: id, }, Auth: false, + IP: ip, Checks: checks( hasResp(http.StatusOK), wasFetchByIDCalled(1, id), @@ -383,6 +405,20 @@ func TestPublic(t *testing.T) { Value: id, }, Auth: false, + IP: ip, + Checks: checks( + hasResp(http.StatusOK), + wasFetchByIDCalled(1, id), + ), + }, + "IP is invalid": { + Store: storeStub(true, nil), + Cookie: &http.Cookie{ + Name: defaultName, + Value: id, + }, + Auth: false, + IP: "127.0.0.2", Checks: checks( hasResp(http.StatusOK), wasFetchByIDCalled(1, id), @@ -395,6 +431,7 @@ func TestPublic(t *testing.T) { Value: id, }, Auth: true, + IP: ip, Checks: checks( hasResp(http.StatusOK), wasFetchByIDCalled(1, id), @@ -418,7 +455,9 @@ func TestPublic(t *testing.T) { rec := httptest.NewRecorder() req := httptest.NewRequest("GET", "http://example.com", nil) req.AddCookie(c.Cookie) - m := Manager{store: c.Store} + req.Header.Set("X-Forwarded-For", c.IP) + req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux i686; rv:38.0) Gecko/20100101 Firefox/38.0") + m := Manager{store: c.Store, validate: true} m.Defaults() m.Public(next(t, c.Auth)).ServeHTTP(rec, req) for _, ch := range c.Checks { @@ -429,6 +468,8 @@ func TestPublic(t *testing.T) { } func TestAuth(t *testing.T) { + ip := "127.0.0.1" + type check func(*testing.T, *StoreMock, *httptest.ResponseRecorder) checks := func(cc ...check) []check { return cc } @@ -461,7 +502,12 @@ func TestAuth(t *testing.T) { storeStub := func(bRes bool, err error) *StoreMock { return &StoreMock{ FetchByIDFunc: func(_ context.Context, _ string) (Session, bool, error) { - return Session{}, bRes, err + s := Session{ + IP: net.ParseIP(ip), + } + s.Agent.OS = useragent.OSLinux + s.Agent.Browser = "Firefox" + return s, bRes, err }, } } @@ -471,6 +517,7 @@ func TestAuth(t *testing.T) { cc := map[string]struct { Store *StoreMock Cookie *http.Cookie + IP string Checks []check }{ "Invalid cookie": { @@ -479,6 +526,7 @@ func TestAuth(t *testing.T) { Name: "incorrect", Value: id, }, + IP: ip, Checks: checks( hasResp(http.StatusUnauthorized, true), wasFetchByIDCalled(0, ""), @@ -490,6 +538,7 @@ func TestAuth(t *testing.T) { Name: defaultName, Value: id, }, + IP: ip, Checks: checks( hasResp(http.StatusUnauthorized, true), wasFetchByIDCalled(1, id), @@ -501,6 +550,19 @@ func TestAuth(t *testing.T) { Name: defaultName, Value: id, }, + IP: ip, + Checks: checks( + hasResp(http.StatusUnauthorized, true), + wasFetchByIDCalled(1, id), + ), + }, + "IP is invalid": { + Store: storeStub(true, nil), + Cookie: &http.Cookie{ + Name: defaultName, + Value: id, + }, + IP: "127.0.0.2", Checks: checks( hasResp(http.StatusUnauthorized, true), wasFetchByIDCalled(1, id), @@ -512,6 +574,7 @@ func TestAuth(t *testing.T) { Name: defaultName, Value: id, }, + IP: ip, Checks: checks( hasResp(http.StatusOK, false), wasFetchByIDCalled(1, id), @@ -535,7 +598,9 @@ func TestAuth(t *testing.T) { rec := httptest.NewRecorder() req := httptest.NewRequest("GET", "http://example.com", nil) req.AddCookie(c.Cookie) - m := Manager{store: c.Store} + req.Header.Set("X-Forwarded-For", c.IP) + req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux i686; rv:38.0) Gecko/20100101 Firefox/38.0") + m := Manager{store: c.Store, validate: true} m.Defaults() m.Auth(next(t)).ServeHTTP(rec, req) for _, ch := range c.Checks { diff --git a/session.go b/session.go index 755ed7a..de4ec4a 100644 --- a/session.go +++ b/session.go @@ -48,6 +48,29 @@ type Session struct { } `json:"agent"` } +// isValid checks whether the incoming request's properties match +// active session's properties. +func (s Session) isValid(r *http.Request) bool { + ip := true + if len(s.IP) != 0 { + ip = s.IP.Equal(readIP(r)) + } + + a := useragent.Parse(r.Header.Get("User-Agent")) + + os := true + if s.Agent.OS != "" { + os = s.Agent.OS == a.OS + } + + browser := true + if s.Agent.Browser != "" { + browser = s.Agent.Browser == a.Name + } + + return ip && os && browser +} + // newSession creates a new Session with the data extracted from // the provided request, user key and a freshly generated ID. func (m *Manager) newSession(r *http.Request, key string) Session { diff --git a/session_test.go b/session_test.go index 4c86e14..c90a2e9 100644 --- a/session_test.go +++ b/session_test.go @@ -12,6 +12,99 @@ import ( "xojoc.pw/useragent" ) +func TestIsValid(t *testing.T) { + ses := Session{ + IP: net.ParseIP("127.0.0.1"), + } + + ses.Agent.OS = useragent.OSWindows + ses.Agent.Browser = "Chrome" + + req := httptest.NewRequest("GET", "http://example.com/", nil) + req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.111 Safari/537.36") + req.RemoteAddr = "127.0.0.1:3000" + + cc := map[string]struct { + Req *http.Request + Session Session + Res bool + }{ + "Invalid IP": { + Req: func() *http.Request { + creq := httptest.NewRequest("GET", "http://example.com/", nil) + creq.Header.Set("User-Agent", req.Header.Get("User-Agent")) + creq.RemoteAddr = "127.0.0.2:3000" + return creq + }(), + Session: ses, + Res: false, + }, + "Invalid User-Agent browser": { + Req: func() *http.Request { + creq := httptest.NewRequest("GET", "http://example.com/", nil) + creq.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/42.0.2311.135 Safari/537.36 Edge/12.246") + creq.RemoteAddr = req.RemoteAddr + return creq + }(), + Session: ses, + Res: false, + }, + "Invalid User-Agent os": { + Req: func() *http.Request { + creq := httptest.NewRequest("GET", "http://example.com/", nil) + creq.Header.Set("User-Agent", "Mozilla/5.0 (Linux; Android 5.1.1; SM-G928X Build/LMY47X) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.83 Mobile Safari/537.36") + creq.RemoteAddr = req.RemoteAddr + return creq + }(), + Session: ses, + Res: false, + }, + "Successful all fields except ip validation": { + Req: req, + Session: func() Session { + cses := ses + cses.IP = nil + return cses + }(), + Res: true, + }, + "Successful all fields except os validation": { + Req: req, + Session: func() Session { + cses := ses + cses.Agent.OS = "" + return cses + }(), + Res: true, + }, + "Successful all fields except browser validation": { + Req: req, + Session: func() Session { + cses := ses + cses.Agent.Browser = "" + return cses + }(), + Res: true, + }, + "Successful all fields validation": { + Req: req, + Session: ses, + Res: true, + }, + } + + for cn, c := range cc { + c := c + t.Run(cn, func(t *testing.T) { + t.Parallel() + res := c.Session.isValid(c.Req) + if res != c.Res { + t.Errorf("want %t, got %t", c.Res, res) + } + }) + } +} + func TestNewSession(t *testing.T) { m := Manager{ expiresIn: time.Hour,