Skip to content

Commit

Permalink
feat(oauth): association userID with clientID
Browse files Browse the repository at this point in the history
  • Loading branch information
blancinot committed Oct 7, 2019
1 parent 8b6cce6 commit b1695a4
Show file tree
Hide file tree
Showing 13 changed files with 321 additions and 79 deletions.
6 changes: 4 additions & 2 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"errors"
"time"

"github.com/dgrijalva/jwt-go"
"github.com/emicklei/go-restful"
jwt "github.com/dgrijalva/jwt-go"
restful "github.com/emicklei/go-restful"
)

var (
Expand All @@ -16,6 +16,7 @@ var (
// Authenticator is the interface for authn backends
type Authenticator interface {
Authenticate(user, password string, expiresAt time.Time) (claims jwt.Claims, err error)
FindUser(clientID, provider string, expiresAt time.Time) (user string, claims jwt.Claims, err error)
}

// API registering with restful
Expand All @@ -36,5 +37,6 @@ func (api *API) Register() *restful.WebService {
api.registerKeystone(ws)
api.registerK8sAuthenticator(ws)
api.registerCertificate(ws)
api.registerOauth(ws)
return ws
}
2 changes: 1 addition & 1 deletion api/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package api
import (
"time"

"github.com/dgrijalva/jwt-go"
jwt "github.com/dgrijalva/jwt-go"

"github.com/mcluseau/autentigo/auth"
)
Expand Down
112 changes: 112 additions & 0 deletions api/oauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package api

import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
"time"

restful "github.com/emicklei/go-restful"
)

const bearerPrefix = "Bearer "

func (api *API) registerOauth(ws *restful.WebService) {
ws.
Route(ws.GET("/oauth/{provider}").
To(api.oauthAuthenticate).
Doc("Authenticate using oauth token").
Param(restful.HeaderParameter("Authorization", "Oauth authorization header")).
Produces("application/json").
Writes(AuthResponse{}))
}

func (api *API) oauthAuthenticate(request *restful.Request, response *restful.Response) {
defer func() {
if err := recover(); err != nil {
// unhandled error
WriteError(err.(error), response)
}
}()

authHeader := request.HeaderParameter("Authorization")
if !strings.HasPrefix(authHeader, bearerPrefix) {
response.WriteErrorString(http.StatusUnauthorized, "missing bearer prefix")
return
}

accessToken := authHeader[len(bearerPrefix):]
provider := request.PathParameter("provider")

baseURL, err := oauthClientIdentityURL(provider)
if err != nil {
response.WriteError(http.StatusBadRequest, err)
return
}

identityResponse, err := http.Get(baseURL + "?access_token=" + accessToken)
if err != nil {
response.WriteError(http.StatusUnauthorized, fmt.Errorf("failed getting client identity by oauth: %s", err.Error()))
return
}

defer identityResponse.Body.Close()
contents, err := ioutil.ReadAll(identityResponse.Body)
if err != nil {
response.WriteError(http.StatusUnprocessableEntity, fmt.Errorf("failed reading response body: %s", err.Error()))
return
}

var clientIdentity map[string]interface{}
if err := json.Unmarshal(contents, &clientIdentity); err != nil {
response.WriteError(http.StatusUnprocessableEntity, fmt.Errorf("failed unmarshalling contents: %s", err.Error()))
return
}

id := safeStringValue(clientIdentity, "id")
if len(id) == 0 {
id = safeStringValue(clientIdentity, "sub") // different between some oauth providers
if len(id) == 0 {
response.WriteErrorString(http.StatusUnprocessableEntity, "client identity given by oauth is unprocessable")
return
}
}

exp := time.Now().Add(api.TokenDuration)
user, claims, err := api.Authenticator.FindUser(id, provider, exp)
if err != nil {
response.WriteError(http.StatusUnprocessableEntity, fmt.Errorf("associated user not found: %s", err.Error()))
return
}
_, tokenString, err := api.createToken(user, claims)
if err != nil {
panic(err)
}

_, err = api.checkToken(tokenString)
if err != nil {
panic(err)
}

response.WriteEntity(&AuthResponse{tokenString, claims})
}

func safeStringValue(m map[string]interface{}, field string) string {
v, ok := m[field]
if !ok {
return ""
}
return v.(string)
}

func oauthClientIdentityURL(provider string) (value string, err error) {
urlEnv := strings.ToUpper(provider) + "_USERIDENTITYURL"
value = os.Getenv(urlEnv)
if len(value) == 0 {
err = fmt.Errorf("client identity url given by provider %s is missing, please verify autentigo configuration [%s]", provider, urlEnv)
}
return
}
56 changes: 49 additions & 7 deletions auth/etcd/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"log"
"os"
"path"
Expand All @@ -17,6 +18,10 @@ import (
"github.com/mcluseau/autentigo/auth"
)

const (
oauthprefix = "/oauth"
)

// New Authenticator with etcd backend
func New(prefix string, endpoints []string) api.Authenticator {
client, err := clientv3.New(clientv3.Config{
Expand Down Expand Up @@ -64,7 +69,29 @@ func (a *etcdAuth) Authenticate(user, password string, expiresAt time.Time) (cla
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
defer cancel()

resp, err := a.client.Get(ctx, path.Join(a.prefix, user))
u := &User{}
if u, err = a.getUser(ctx, user); err != nil {
return
}

if u.PasswordHash != passwordHash {
err = api.ErrInvalidAuthentication
return
}

claims = auth.Claims{
StandardClaims: jwt.StandardClaims{
IssuedAt: time.Now().Unix(),
ExpiresAt: expiresAt.Unix(),
Subject: user,
},
ExtraClaims: u.ExtraClaims,
}
return
}

func (a *etcdAuth) getUser(ctx context.Context, userID string) (user *User, err error) {
resp, err := a.client.Get(ctx, path.Join(a.prefix, userID))
if err != nil {
return
}
Expand All @@ -74,23 +101,38 @@ func (a *etcdAuth) Authenticate(user, password string, expiresAt time.Time) (cla
return
}

u := User{}
if err = json.Unmarshal(resp.Kvs[0].Value, &u); err != nil {
err = json.Unmarshal(resp.Kvs[0].Value, user)
return
}

func (a *etcdAuth) FindUser(clientID, provider string, expiresAt time.Time) (userID string, claims jwt.Claims, err error) {

ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
defer cancel()

var resp *clientv3.GetResponse
if resp, err = a.client.Get(ctx, path.Join(oauthprefix, a.prefix, provider, clientID)); err != nil {
return
}

if len(resp.Kvs) == 0 {
err = errors.New("unknown user")
return
}

if u.PasswordHash != passwordHash {
err = api.ErrInvalidAuthentication
userID = string(resp.Kvs[0].Value)
user := &User{}
if user, err = a.getUser(ctx, userID); err != nil {
return
}

claims = auth.Claims{
StandardClaims: jwt.StandardClaims{
IssuedAt: time.Now().Unix(),
ExpiresAt: expiresAt.Unix(),
Subject: user,
Subject: userID,
},
ExtraClaims: u.ExtraClaims,
ExtraClaims: user.ExtraClaims,
}
return
}
10 changes: 8 additions & 2 deletions auth/ldap-bind/ldap-bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ package ldapbind

import (
"crypto/tls"
"errors"
"fmt"
"log"
"net/url"
"time"

"github.com/dgrijalva/jwt-go"
jwt "github.com/dgrijalva/jwt-go"
"github.com/mcluseau/autentigo/api"
"gopkg.in/ldap.v2"
ldap "gopkg.in/ldap.v2"
)

// New Authenticator with ldap backend
Expand Down Expand Up @@ -64,3 +65,8 @@ func (a auth) Authenticate(user, password string, expiresAt time.Time) (jwt.Clai
Subject: user,
}, nil
}

func (a auth) FindUser(clientID, provider string, expiresAt time.Time) (userID string, claims jwt.Claims, err error) {
err = errors.New("inconsistent with Ldap backend")
return
}
6 changes: 6 additions & 0 deletions auth/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"strings"
"time"
Expand Down Expand Up @@ -74,3 +75,8 @@ func (sa sqlAuth) Authenticate(user, password string, expiresAt time.Time) (clai

return
}

func (sa sqlAuth) FindUser(clientID, provider string, expiresAt time.Time) (userID string, claims jwt.Claims, err error) {
err = errors.New("Not implemented yet")
return
}
8 changes: 7 additions & 1 deletion auth/stupid-auth/stupid-auth.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package stupidauth

import (
"errors"
"time"

"github.com/dgrijalva/jwt-go"
jwt "github.com/dgrijalva/jwt-go"
"github.com/mcluseau/autentigo/api"
)

Expand All @@ -23,3 +24,8 @@ func (sa stupidAuth) Authenticate(user, password string, expiresAt time.Time) (j
Subject: user,
}, nil
}

func (sa stupidAuth) FindUser(clientID, provider string, expiresAt time.Time) (userID string, claims jwt.Claims, err error) {
err = errors.New("inconsistent with stupid auth")
return
}
6 changes: 6 additions & 0 deletions auth/users-file/users-file.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/sha256"
"encoding/csv"
"encoding/hex"
"errors"
"io"
"os"
"strings"
Expand Down Expand Up @@ -96,3 +97,8 @@ func (a usersFileAuth) Authenticate(user, password string, expiresAt time.Time)

return nil, api.ErrInvalidAuthentication
}

func (a usersFileAuth) FindUser(clientID, provider string, expiresAt time.Time) (userID string, claims jwt.Claims, err error) {
err = errors.New("Not implemented yet")
return
}
18 changes: 8 additions & 10 deletions cmd/ag-companion-api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"flag"
"io/ioutil"
"log"
"net"
"net/http"
Expand All @@ -17,8 +16,7 @@ import (
companionapi "github.com/mcluseau/autentigo/pkg/companion-api/api"
"github.com/mcluseau/autentigo/pkg/companion-api/backend"
"github.com/mcluseau/autentigo/pkg/companion-api/backend/etcd"
"github.com/mcluseau/autentigo/pkg/companion-api/backend/users-file"
"github.com/mcluseau/autentigo/pkg/rbac"
usersfile "github.com/mcluseau/autentigo/pkg/companion-api/backend/users-file"
)

var (
Expand All @@ -36,14 +34,14 @@ func main() {

var err error

if rbac.Default, err = rbac.FromFile(*rbacFile); err != nil {
log.Fatal("failed to load RBAC rules: ", err)
}

if rbac.DefaultValidationCertificate, err = ioutil.ReadFile(*validationCrtPath); err != nil {
log.Fatal("failed to read validation certificate: ", err)
}
/* if rbac.Default, err = rbac.FromFile(*rbacFile); err != nil {
log.Fatal("failed to load RBAC rules: ", err)
}
if rbac.DefaultValidationCertificate, err = ioutil.ReadFile(*validationCrtPath); err != nil {
log.Fatal("failed to read validation certificate: ", err)
}
*/
cAPI := &companionapi.CompanionAPI{
Client: getBackEndClient(),
AdminToken: *adminToken,
Expand Down
Loading

0 comments on commit b1695a4

Please sign in to comment.