-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0b25d3c
commit e625649
Showing
18 changed files
with
782 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
package auth | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"log" | ||
"net/http" | ||
"os" | ||
"strings" | ||
"time" | ||
|
||
firebase "firebase.google.com/go" | ||
"firebase.google.com/go/auth" | ||
"github.com/jinzhu/gorm" | ||
) | ||
|
||
type ( | ||
// contextKey represents keys into a request context | ||
contextKey string | ||
// Token is a proxy of the Firebase Auth SDKs token, so that importing packages won't need to import both auth packages explicitly | ||
Token auth.Token | ||
|
||
// User is just a gorm model wrapper for the Firebase UID to support circle queries | ||
User struct { | ||
ID string `json:"id"` | ||
CircleID string `json:"circleId"` | ||
CreatedAt time.Time `json:"-"` | ||
UpdatedAt time.Time `json:"-"` | ||
DeletedAt *time.Time `json:"-"` | ||
} | ||
|
||
// Profile contains a user's profile information | ||
Profile struct { | ||
UID string `json:"uid"` | ||
Name string `json:"name"` | ||
ProfilePicture string `json:"profilePicture"` | ||
} | ||
) | ||
|
||
const ( | ||
// TestUID is the UID of a test user | ||
TestUID = "UpIEj9XrQNMzdOQDgPSY0MGSsnO2" | ||
|
||
googleApplicationCredentialsKey = "GOOGLE_APPLICATION_CREDENTIALS" | ||
|
||
// contextKeyAuthToken is the key in context under which the auth token is stored in AuthenticateRequest | ||
contextKeyAuthToken = contextKey("auth-token") | ||
) | ||
|
||
var firebaseApp *firebase.App | ||
|
||
// GetUser retrieves the UID from a request's token and returns the user model | ||
func GetUser(r *http.Request, db *gorm.DB) (User, error) { | ||
// Retrieve the auth token from the request context | ||
token, err := GetTokenFrom(r.Context()) | ||
if err != nil { | ||
return User{}, err | ||
} | ||
// Retrieve the client UID from the token | ||
currentUID := token.UID | ||
|
||
// Fetch the current user | ||
var user User | ||
if result := db.FirstOrCreate(&user, User{ID: currentUID}); result.Error != nil { | ||
return User{}, result.Error | ||
} | ||
return user, nil | ||
} | ||
|
||
// GetUserProfiles returns the users' profiles | ||
func GetUserProfiles(users ...User) ([]Profile, error) { | ||
ctx := context.Background() | ||
client, err := getFireBaseApp().Auth(ctx) | ||
if err != nil { | ||
return []Profile{}, fmt.Errorf("error getting firebase app: %v", err) | ||
} | ||
profiles := make([]Profile, len(users)) | ||
for i, u := range users { | ||
p := Profile{UID: u.ID} | ||
userRecord, err := client.GetUser(ctx, u.ID) | ||
if err == nil { | ||
p.Name = userRecord.DisplayName | ||
p.ProfilePicture = userRecord.PhotoURL | ||
} | ||
profiles[i] = p | ||
} | ||
return profiles, nil | ||
} | ||
|
||
// GetTokenFrom retrieves the access token from a request context. It returns an error if the token isn't found | ||
func GetTokenFrom(ctx context.Context) (*Token, error) { | ||
token := ctx.Value(contextKeyAuthToken) | ||
if token == nil { | ||
err := errors.New("Error: context doesn't contain a token") | ||
return nil, err | ||
} | ||
t := token.(*Token) | ||
return t, nil | ||
} | ||
|
||
// AddTokenTo adds an access token to a request context | ||
func AddTokenTo(ctx context.Context, token *Token) context.Context { | ||
return context.WithValue(ctx, contextKeyAuthToken, token) | ||
} | ||
|
||
// Middleware extracts the auth token from the request, verifies it, and returns an updated context with the validated auth token. | ||
// If the auth token is missing or invalid, the function writes a 403 Unauthorized status and error message to the response. | ||
func Middleware(next http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
// Retrieve the context from the request | ||
ctx := r.Context() | ||
// Retrieve the Authorization header from the request | ||
authHeader := r.Header.Get("Authorization") | ||
if authHeader == "" { | ||
err := errors.New("Missing auth token") | ||
log.Printf("Error verifying token: %v\n", err.Error()) | ||
http.Error(w, err.Error(), http.StatusUnauthorized) | ||
return | ||
} | ||
// Split the Authorization header (since format should be "Bearer [TOKEN]") | ||
splitAuthHeader := strings.Split(authHeader, " ") | ||
if len(splitAuthHeader) != 2 { | ||
err := errors.New("Invalid/malformed auth token") | ||
log.Printf("Error verifying token: %v\n", err.Error()) | ||
http.Error(w, err.Error(), http.StatusUnauthorized) | ||
return | ||
} | ||
// Retrieve the authorization token | ||
authToken := splitAuthHeader[1] | ||
|
||
// Get a reference to the Auth client | ||
client, err := getFireBaseApp().Auth(ctx) | ||
if err != nil { | ||
log.Fatalf("error getting Auth client: %v\n", err) | ||
} | ||
|
||
// Verify the auth token | ||
firebaseToken, err := client.VerifyIDTokenAndCheckRevoked(ctx, authToken) | ||
if err != nil { | ||
log.Printf("Error verifying token: %v\n", err.Error()) | ||
http.Error(w, "Error: Invalid auth token", http.StatusUnauthorized) | ||
return | ||
} | ||
|
||
token := Token(*firebaseToken) | ||
|
||
// Add the token to the context and pass it to the next handler | ||
ctx = context.WithValue(ctx, contextKeyAuthToken, &token) | ||
next.ServeHTTP(w, r.WithContext(ctx)) | ||
}) | ||
} | ||
|
||
func initializeApp() { | ||
|
||
if os.Getenv(googleApplicationCredentialsKey) == "" { | ||
log.Fatalf("auth: %v environment variable not set. Ensure you set %v to the path"+ | ||
" to your service account JSON", googleApplicationCredentialsKey, googleApplicationCredentialsKey) | ||
} | ||
var err error | ||
firebaseApp, err = firebase.NewApp(context.Background(), nil) | ||
if err != nil { | ||
log.Fatalf("error initializing Firebase app: %v\n", err) | ||
} | ||
} | ||
|
||
func getFireBaseApp() *firebase.App { | ||
if firebaseApp == nil { | ||
initializeApp() | ||
} | ||
return firebaseApp | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
package auth | ||
|
||
import ( | ||
"log" | ||
"net/http" | ||
"net/http/httptest" | ||
"os" | ||
"testing" | ||
|
||
"github.com/safe-distance/socium-infra/pkg/common" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestMain(m *testing.M) { | ||
if err := common.LoadEnv(); err != nil { | ||
log.Fatalln(err) | ||
} | ||
os.Exit(m.Run()) | ||
} | ||
|
||
func addTokenToRequest(r *http.Request, token string) { | ||
r.Header.Add("Authorization", "Bearer "+token) | ||
} | ||
|
||
func testMiddlewareHelper(t *testing.T, shouldSucceed bool, testToken func(validToken string) string) { | ||
// api key retrieved from https://console.firebase.google.com/u/0/project/safe-distance-e4683/settings/general | ||
token, err := GenerateToken(TestUID) | ||
assert.Nil(t, err) | ||
|
||
// Create a test Request with the ID token and a ResponseWriter | ||
r, err := http.NewRequest("", "", nil) | ||
assert.Nil(t, err) | ||
addTokenToRequest(r, testToken(token)) | ||
w := httptest.NewRecorder() | ||
|
||
handler := Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) | ||
|
||
// Call AuthenticateRequest with the test data | ||
handler.ServeHTTP(w, r) | ||
|
||
var correctStatus int | ||
if shouldSucceed { | ||
correctStatus = http.StatusOK | ||
} else { | ||
correctStatus = http.StatusUnauthorized | ||
} | ||
|
||
if w.Code != correctStatus { | ||
t.Fatalf("AuthorizationMiddleware test failed: response code was %v, should have been %v", w.Code, correctStatus) | ||
} | ||
} | ||
|
||
// TestAuthenticationMiddlewareValidToken tests the AuthenticateMiddleware with a valid ID Token and a dummy terminal handler | ||
func TestAuthenticationMiddlewareValidToken(t *testing.T) { | ||
testMiddlewareHelper(t, true, func(validToken string) string { | ||
return validToken | ||
}) | ||
} | ||
|
||
func TestAuthenticationMiddleWareNoToken(t *testing.T) { | ||
testMiddlewareHelper(t, false, func(validToken string) string { | ||
return "" | ||
}) | ||
} | ||
|
||
func TestAuthenticationMiddleWareInvalidToken(t *testing.T) { | ||
testMiddlewareHelper(t, false, func(validToken string) string { | ||
return "INVALID_TOKEN" | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package auth | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"net/http" | ||
"os" | ||
) | ||
|
||
// GenerateToken generates a valid Firebase auth token | ||
func GenerateToken(uid string) (string, error) { | ||
|
||
// api key retrieved from https://console.firebase.google.com/u/0/project/safe-distance-e4683/settings/general | ||
apiKey := os.Getenv("GOOGLE_API_KEY") | ||
|
||
if apiKey == "" { | ||
return "", errors.New("Error: GOOGLE_API_KEY env variable not set") | ||
} | ||
|
||
ctx := context.Background() | ||
client, err := getFireBaseApp().Auth(ctx) | ||
if err != nil { | ||
return "", err | ||
} | ||
// Generate a custom token based off the uid | ||
customToken, err := client.CustomToken(ctx, uid) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
// Request an ID token using the custom token | ||
url := fmt.Sprintf("https://identitytoolkit.googleapis.com/v1/accounts:signInWithCustomToken?key=%s", apiKey) | ||
data := map[string]interface{}{"token": customToken, "returnSecureToken": true} | ||
payload, _ := json.Marshal(data) | ||
var res *http.Response | ||
if res, err = http.Post(url, "application/json", bytes.NewBuffer(payload)); err != nil { | ||
return "", err | ||
} | ||
|
||
if res.StatusCode != http.StatusOK { | ||
return "", errors.New("Error requesting auth token: request failed") | ||
} | ||
|
||
// Retrieve the ID token from the response | ||
var bodyData map[string]interface{} | ||
if err := json.NewDecoder(res.Body).Decode(&bodyData); err != nil { | ||
return "", err | ||
} | ||
token := bodyData["idToken"].(string) | ||
|
||
return token, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package common | ||
|
||
import ( | ||
"os" | ||
|
||
"github.com/joho/godotenv" | ||
) | ||
|
||
// LoadEnv loads the environment variables from the provided files. | ||
// If no filenames are passed, it loads variables from "./.env" by default. | ||
func LoadEnv() error { | ||
rootDir := os.Getenv("PROJECT_ROOT") | ||
if err := godotenv.Load(rootDir + "/.env"); err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
type environmentVar string | ||
|
||
var ( | ||
dbProvider = environmentVar("DB_PROVIDER") | ||
dbConnectionString = environmentVar("DB_CONNECTION_STRING") | ||
) | ||
|
||
func (e environmentVar) String() string { | ||
return string(e) | ||
} |
Oops, something went wrong.