diff --git a/cmd/auth-rest/startcmd/start.go b/cmd/auth-rest/startcmd/start.go index 2dc2eae..7278f3b 100644 --- a/cmd/auth-rest/startcmd/start.go +++ b/cmd/auth-rest/startcmd/start.go @@ -79,6 +79,11 @@ const ( " For CouchDB, include the username:password@ text if required." + " Alternatively, this can be set with the following environment variable: " + databaseURLEnvKey + staticFilesPathFlagName = "static-path" + staticFilesPathFlagUsage = "Path to the folder where the static files are to be hosted under " + uiEndpoint + "." + + "Alternatively, this can be set with the following environment variable: " + staticFilesPathEnvKey + staticFilesPathEnvKey = "AUTH_REST_STATIC_FILES" + databasePrefixFlagName = "database-prefix" databasePrefixEnvKey = "AUTH_REST_DATABASE_PREFIX" databasePrefixFlagShorthand = "p" @@ -129,6 +134,7 @@ const ( const ( // api + uiEndpoint = "/ui" healthCheckEndpoint = "/healthcheck" ) @@ -143,6 +149,7 @@ type authRestParameters struct { tlsParams *tlsParams oidcParams *oidcParams bootstrapParams *bootstrapParams + staticFiles string } type tlsParams struct { @@ -210,7 +217,7 @@ func createStartCmd(srv server) *cobra.Command { } } -func getAuthRestParameters(cmd *cobra.Command) (*authRestParameters, error) { +func getAuthRestParameters(cmd *cobra.Command) (*authRestParameters, error) { //nolint:funlen,gocyclo hostURL, err := cmdutils.GetUserSetVarFromString(cmd, hostURLFlagName, hostURLEnvKey, false) if err != nil { return nil, err @@ -231,6 +238,11 @@ func getAuthRestParameters(cmd *cobra.Command) (*authRestParameters, error) { return nil, err } + staticFiles, err := cmdutils.GetUserSetVarFromString(cmd, staticFilesPathFlagName, staticFilesPathEnvKey, true) + if err != nil { + return nil, err + } + var databaseURL string if databaseType == databaseTypeMemOption { databaseURL = "N/A" @@ -266,6 +278,7 @@ func getAuthRestParameters(cmd *cobra.Command) (*authRestParameters, error) { databasePrefix: databasePrefix, oidcParams: oidcParams, bootstrapParams: bootstrapParams, + staticFiles: staticFiles, }, nil } @@ -311,6 +324,7 @@ func createFlags(startCmd *cobra.Command) { startCmd.Flags().StringP(tlsServeCertPathFlagName, "", "", tlsServeCertPathFlagUsage) startCmd.Flags().StringP(tlsServeKeyPathFlagName, "", "", tlsServeKeyPathFlagUsage) startCmd.Flags().StringP(logLevelFlagName, logLevelFlagShorthand, "", logLevelPrefixFlagUsage) + startCmd.Flags().StringP(staticFilesPathFlagName, "", "", staticFilesPathFlagUsage) startCmd.Flags().StringP(databaseTypeFlagName, databaseTypeFlagShorthand, "", databaseTypeFlagUsage) startCmd.Flags().StringP(databaseURLFlagName, databaseURLFlagShorthand, "", databaseURLFlagUsage) startCmd.Flags().StringP(databasePrefixFlagName, databasePrefixFlagShorthand, "", databasePrefixFlagUsage) @@ -340,7 +354,6 @@ func startAuthService(parameters *authRestParameters, srv server) error { logger.Debugf("root ca's %v", rootCAs) router := mux.NewRouter() - // health check router.HandleFunc(healthCheckEndpoint, healthCheckHandler).Methods(http.MethodGet) @@ -368,12 +381,16 @@ func startAuthService(parameters *authRestParameters, srv server) error { router.HandleFunc(handler.Path(), handler.Handle()).Methods(handler.Method()) } - logger.Infof(`Starting hub-auth REST server with the following parameters: -Host URL: %s -Database type: %s + logger.Infof(`Starting hub-auth REST server with the following parameters:Host URL: %s Database type: %s Database URL: %s Database prefix: %s`, parameters.hostURL, parameters.databaseType, parameters.databaseURL, parameters.databasePrefix) + // static frontend + router.PathPrefix(uiEndpoint). + Subrouter(). + Methods(http.MethodGet). + HandlerFunc(uiHandler(parameters.staticFiles, http.ServeFile)) + return srv.ListenAndServeTLS( parameters.hostURL, parameters.tlsParams.serveCertPath, @@ -382,6 +399,19 @@ Database prefix: %s`, parameters.hostURL, parameters.databaseType, parameters.da ) } +func uiHandler( + basePath string, + fileServer func(http.ResponseWriter, *http.Request, string)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == uiEndpoint { + fileServer(w, r, strings.ReplaceAll(basePath+"/index.html", "//", "/")) + return + } + + fileServer(w, r, strings.ReplaceAll(basePath+"/"+r.URL.Path[len(uiEndpoint):], "//", "/")) + } +} + func getOIDCParams(cmd *cobra.Command) (*oidcParams, error) { params := &oidcParams{} diff --git a/cmd/auth-rest/startcmd/start_test.go b/cmd/auth-rest/startcmd/start_test.go index 3294fc3..15d2d09 100644 --- a/cmd/auth-rest/startcmd/start_test.go +++ b/cmd/auth-rest/startcmd/start_test.go @@ -11,6 +11,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "os" "strings" "testing" @@ -197,6 +198,26 @@ func TestStartCmdWithBlankEnvVar(t *testing.T) { }) } +func TestUIHandler(t *testing.T) { + t.Run("handle base path", func(t *testing.T) { + handled := false + uiHandler(uiEndpoint, func(_ http.ResponseWriter, _ *http.Request, path string) { + handled = true + require.Equal(t, uiEndpoint+"/index.html", path) + })(nil, &http.Request{URL: &url.URL{Path: uiEndpoint}}) + require.True(t, handled) + }) + t.Run("handle subpaths", func(t *testing.T) { + const expected = uiEndpoint + "/css/abc123.css" + handled := false + uiHandler(uiEndpoint, func(_ http.ResponseWriter, _ *http.Request, path string) { + handled = true + require.Equal(t, expected, path) + })(nil, &http.Request{URL: &url.URL{Path: expected}}) + require.True(t, handled) + }) +} + func TestStartCmdValidArgs(t *testing.T) { t.Run("In-memory storage, valid log level", func(t *testing.T) { oidcURL := mockOIDCProvider(t) diff --git a/pkg/restapi/operation/operations.go b/pkg/restapi/operation/operations.go index 1aa1f80..e88cde6 100644 --- a/pkg/restapi/operation/operations.go +++ b/pkg/restapi/operation/operations.go @@ -13,6 +13,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "github.com/coreos/go-oidc" "github.com/google/uuid" @@ -33,6 +34,9 @@ const ( transientStoreName = "hub-auth-rest-transient" bootstrapStoreName = "bootstrap-data" + + // redirect url parameter + userProfileQueryParam = "up" ) var logger = log.New("hub-auth-restapi") @@ -116,6 +120,7 @@ type Operation struct { oidcClientID string oidcClientSecret string oidcCallbackURL string + uiEndpoint string oauth2ConfigFunc func(...string) oauth2Config bootstrapStore storage.Store bootstrapConfig *BootstrapConfig @@ -129,6 +134,7 @@ type Config struct { OIDCClientID string OIDCClientSecret string OIDCCallbackURL string + UIEndpoint string TransientStoreProvider storage.Provider StoreProvider storage.Provider BootstrapConfig *BootstrapConfig @@ -329,7 +335,15 @@ func (c *Operation) handleOIDCCallback(w http.ResponseWriter, r *http.Request) { return } - handleAuthResult(w, r, userProfile) + profileBytes, err := json.Marshal(userProfile) + if err != nil { + c.writeErrorResponse(w, http.StatusInternalServerError, + fmt.Sprintf("failed to marshal user profile data : %s", err)) + + return + } + + c.handleAuthResult(w, r, profileBytes) } // TODO onboard user at key server and SDS: https://github.com/trustbloc/hub-auth/issues/38 @@ -347,7 +361,7 @@ func (c *Operation) onboardUser(id string) (*user.Profile, error) { } func (c *Operation) handleBootstrapDataRequest(w http.ResponseWriter, r *http.Request) { - handle := r.URL.Query().Get("up") + handle := r.URL.Query().Get(userProfileQueryParam) if handle == "" { handleAuthError(w, http.StatusBadRequest, "missing handle") @@ -388,9 +402,21 @@ func (c *Operation) handleBootstrapDataRequest(w http.ResponseWriter, r *http.Re } } -// TODO redirect to the UI: https://github.com/trustbloc/hub-auth/issues/39 -func handleAuthResult(w http.ResponseWriter, r *http.Request, _ *user.Profile) { - http.Redirect(w, r, "", http.StatusFound) +func (c *Operation) handleAuthResult(w http.ResponseWriter, r *http.Request, profileBytes []byte) { + handle := url.QueryEscape(uuid.New().String()) + + err := c.transientStore.Put(handle, profileBytes) + if err != nil { + c.writeErrorResponse(w, + http.StatusInternalServerError, fmt.Sprintf("failed to write handle to transient store: %s", err)) + + return + } + + redirectURL := fmt.Sprintf("%s?%s=%s", c.uiEndpoint, userProfileQueryParam, handle) + + http.Redirect(w, r, redirectURL, http.StatusFound) + logger.Debugf("redirected to: %s", redirectURL) } func handleAuthError(w http.ResponseWriter, status int, msg string) { diff --git a/pkg/restapi/operation/operations_test.go b/pkg/restapi/operation/operations_test.go index 5a5d63e..3914e92 100644 --- a/pkg/restapi/operation/operations_test.go +++ b/pkg/restapi/operation/operations_test.go @@ -411,6 +411,30 @@ func TestHandleOIDCCallback(t *testing.T) { svc.handleOIDCCallback(result, newOIDCCallback(state, "code")) require.Equal(t, http.StatusInternalServerError, result.Code) }) + t.Run("PUT error while storing user info while handling callback user", func(t *testing.T) { + id := uuid.New().String() + state := uuid.New().String() + config := config(t) + + config.TransientStoreProvider = &mockstorage.Provider{ + Stores: map[string]storage.Store{ + transientStoreName: &mockstore.MockStore{ + Store: map[string][]byte{ + id: []byte("{}"), + }, + ErrGet: storage.ErrValueNotFound, + ErrPut: errors.New("generic"), + }, + }, + } + + svc, err := New(config) + require.NoError(t, err) + + result := httptest.NewRecorder() + svc.handleAuthResult(result, newOIDCCallback(state, "code"), nil) + require.Equal(t, http.StatusInternalServerError, result.Code) + }) } func TestHandleBootstrapDataRequest(t *testing.T) { @@ -487,7 +511,7 @@ func newOIDCCallback(state, code string) *http.Request { func newBootstrapDataRequest(handle string) *http.Request { return httptest.NewRequest(http.MethodGet, - fmt.Sprintf("http://example.com/bootstrap?up=%s", handle), nil) + fmt.Sprintf("http://example.com/bootstrap?%s=%s", userProfileQueryParam,handle), nil) } type mockOIDCProvider struct {