diff --git a/api/handler/v1beta1/appeal_test.go b/api/handler/v1beta1/appeal_test.go index fa3b837ac..c01b97e1e 100644 --- a/api/handler/v1beta1/appeal_test.go +++ b/api/handler/v1beta1/appeal_test.go @@ -128,11 +128,7 @@ func (s *GrpcHandlersSuite) TestListUserAppeals() { ResourceUrns: []string{"test-resource-urn"}, OrderBy: []string{"test-order"}, } - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: expectedUser, - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, expectedUser) res, err := s.grpcServer.ListUserAppeals(ctx, req) s.NoError(err) @@ -162,11 +158,7 @@ func (s *GrpcHandlersSuite) TestListUserAppeals() { Return(nil, expectedError).Once() req := &guardianv1beta1.ListUserAppealsRequest{} - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: "test-user", - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, "test-user") res, err := s.grpcServer.ListUserAppeals(ctx, req) s.Equal(codes.Internal, status.Code(err)) @@ -188,11 +180,7 @@ func (s *GrpcHandlersSuite) TestListUserAppeals() { Return(invalidAppeals, nil).Once() req := &guardianv1beta1.ListUserAppealsRequest{} - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: "test-user", - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, "test-user") res, err := s.grpcServer.ListUserAppeals(ctx, req) s.Equal(codes.Internal, status.Code(err)) @@ -484,11 +472,7 @@ func (s *GrpcHandlersSuite) TestCreateAppeal() { }, Description: "The answer is 42", } - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: expectedUser, - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, expectedUser) res, err := s.grpcServer.CreateAppeal(ctx, req) s.NoError(err) @@ -520,11 +504,7 @@ func (s *GrpcHandlersSuite) TestCreateAppeal() { s.appealService.EXPECT().Create(mock.AnythingOfType("*context.valueCtx"), mock.Anything).Return(appeal.ErrAppealDuplicate).Once() req := &guardianv1beta1.CreateAppealRequest{} - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: "user@example.com", - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, "user@example.com") res, err := s.grpcServer.CreateAppeal(ctx, req) s.Equal(codes.AlreadyExists, status.Code(err)) @@ -539,11 +519,7 @@ func (s *GrpcHandlersSuite) TestCreateAppeal() { s.appealService.EXPECT().Create(mock.AnythingOfType("*context.valueCtx"), mock.Anything).Return(expectedError).Once() req := &guardianv1beta1.CreateAppealRequest{} - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: "user@example.com", - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, "user@example.com") res, err := s.grpcServer.CreateAppeal(ctx, req) s.Equal(codes.Internal, status.Code(err)) @@ -567,11 +543,7 @@ func (s *GrpcHandlersSuite) TestCreateAppeal() { Return(nil).Once() req := &guardianv1beta1.CreateAppealRequest{Resources: make([]*guardianv1beta1.CreateAppealRequest_Resource, 1)} - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: "user@example.com", - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, "user@example.com") res, err := s.grpcServer.CreateAppeal(ctx, req) s.Equal(codes.Internal, status.Code(err)) diff --git a/api/handler/v1beta1/approval_test.go b/api/handler/v1beta1/approval_test.go index 227d28592..57bb3660b 100644 --- a/api/handler/v1beta1/approval_test.go +++ b/api/handler/v1beta1/approval_test.go @@ -116,11 +116,7 @@ func (s *GrpcHandlersSuite) TestListUserApprovals() { Statuses: []string{"active", "pending"}, OrderBy: []string{"test-order"}, } - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: expectedUser, - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, expectedUser) res, err := s.grpcServer.ListUserApprovals(ctx, req) s.NoError(err) @@ -150,11 +146,7 @@ func (s *GrpcHandlersSuite) TestListUserApprovals() { Return(nil, expectedError).Once() req := &guardianv1beta1.ListUserApprovalsRequest{} - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: "test-user", - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, "test-user") res, err := s.grpcServer.ListUserApprovals(ctx, req) s.Equal(codes.Internal, status.Code(err)) @@ -178,11 +170,7 @@ func (s *GrpcHandlersSuite) TestListUserApprovals() { Return(invalidApprovals, nil).Once() req := &guardianv1beta1.ListUserApprovalsRequest{} - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: "test-user", - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, "test-user") res, err := s.grpcServer.ListUserApprovals(ctx, req) s.Equal(codes.Internal, status.Code(err)) @@ -443,11 +431,7 @@ func (s *GrpcHandlersSuite) TestUpdateApproval() { Reason: expectedReason, }, } - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: expectedUser, - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, expectedUser) res, err := s.grpcServer.UpdateApproval(ctx, req) s.NoError(err) @@ -553,11 +537,7 @@ func (s *GrpcHandlersSuite) TestUpdateApproval() { Return(nil, tc.expectedError).Once() req := &guardianv1beta1.UpdateApprovalRequest{} - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: expectedUser, - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, expectedUser) res, err := s.grpcServer.UpdateApproval(ctx, req) s.Equal(tc.expectedStatusCode, status.Code(err)) @@ -579,11 +559,7 @@ func (s *GrpcHandlersSuite) TestUpdateApproval() { Return(invalidAppeal, nil).Once() req := &guardianv1beta1.UpdateApprovalRequest{} - ctx := context.Background() - md := metadata.New(map[string]string{ - s.authenticatedUserHeaderKey: "user@example.com", - }) - ctx = metadata.NewIncomingContext(ctx, md) + ctx := context.WithValue(context.Background(), authEmailTestContextKey{}, "user@example.com") res, err := s.grpcServer.UpdateApproval(ctx, req) s.Equal(codes.Internal, status.Code(err)) diff --git a/api/handler/v1beta1/grpc.go b/api/handler/v1beta1/grpc.go index afe66ea75..8aed90424 100644 --- a/api/handler/v1beta1/grpc.go +++ b/api/handler/v1beta1/grpc.go @@ -2,14 +2,15 @@ package v1beta1 import ( "context" - "errors" + "strings" "github.com/odpf/guardian/core/appeal" "github.com/odpf/guardian/core/grant" guardianv1beta1 "github.com/odpf/guardian/api/proto/odpf/guardian/v1beta1" "github.com/odpf/guardian/domain" - "google.golang.org/grpc/metadata" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type ProtoAdapter interface { @@ -116,7 +117,7 @@ type GRPCServer struct { grantService grantService adapter ProtoAdapter - authenticatedUserHeaderKey string + authenticatedUserContextKey interface{} guardianv1beta1.UnimplementedGuardianServiceServer } @@ -130,32 +131,30 @@ func NewGRPCServer( approvalService approvalService, grantService grantService, adapter ProtoAdapter, - authenticatedUserHeaderKey string, + authenticatedUserContextKey interface{}, ) *GRPCServer { return &GRPCServer{ - resourceService: resourceService, - activityService: activityService, - providerService: providerService, - policyService: policyService, - appealService: appealService, - approvalService: approvalService, - grantService: grantService, - adapter: adapter, - authenticatedUserHeaderKey: authenticatedUserHeaderKey, + resourceService: resourceService, + activityService: activityService, + providerService: providerService, + policyService: policyService, + appealService: appealService, + approvalService: approvalService, + grantService: grantService, + adapter: adapter, + authenticatedUserContextKey: authenticatedUserContextKey, } } func (s *GRPCServer) getUser(ctx context.Context) (string, error) { - md, ok := metadata.FromIncomingContext(ctx) + authenticatedEmail, ok := ctx.Value(s.authenticatedUserContextKey).(string) if !ok { - return "", errors.New("unable to retrieve metadata from context") + return "", status.Error(codes.Unauthenticated, "unable to get authenticated user from context") } - users := md.Get(s.authenticatedUserHeaderKey) - if len(users) == 0 { - return "", errors.New("user email not found") + if strings.TrimSpace(authenticatedEmail) == "" { + return "", status.Error(codes.Unauthenticated, "unable to get authenticated user from context") } - currentUser := users[0] - return currentUser, nil + return authenticatedEmail, nil } diff --git a/api/handler/v1beta1/grpc_test.go b/api/handler/v1beta1/grpc_test.go index 72e68234f..a223e7b5e 100644 --- a/api/handler/v1beta1/grpc_test.go +++ b/api/handler/v1beta1/grpc_test.go @@ -8,6 +8,8 @@ import ( "github.com/stretchr/testify/suite" ) +type authEmailTestContextKey struct{} + type GrpcHandlersSuite struct { suite.Suite @@ -19,8 +21,6 @@ type GrpcHandlersSuite struct { approvalService *mocks.ApprovalService grantService *mocks.GrantService grpcServer *v1beta1.GRPCServer - - authenticatedUserHeaderKey string } func TestGrpcHandler(t *testing.T) { @@ -35,7 +35,6 @@ func (s *GrpcHandlersSuite) setup() { s.appealService = new(mocks.AppealService) s.approvalService = new(mocks.ApprovalService) s.grantService = new(mocks.GrantService) - s.authenticatedUserHeaderKey = "test-header-key" s.grpcServer = v1beta1.NewGRPCServer( s.resourceService, s.activityService, @@ -45,6 +44,6 @@ func (s *GrpcHandlersSuite) setup() { s.approvalService, s.grantService, v1beta1.NewAdapter(), - s.authenticatedUserHeaderKey, + authEmailTestContextKey{}, ) } diff --git a/internal/server/auth.go b/internal/server/auth.go index 07c45bedb..915feff0e 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -11,19 +11,29 @@ import ( type authenticatedUserEmailContextKey struct{} +var logrusActorKey = "actor" + func withAuthenticatedUserEmail(headerKey string) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { if md, ok := metadata.FromIncomingContext(ctx); ok { if v := md.Get(headerKey); len(v) > 0 { userEmail := v[0] ctx = context.WithValue(ctx, authenticatedUserEmailContextKey{}, userEmail) - - ctx_logrus.AddFields(ctx, logrus.Fields{ - headerKey: userEmail, - }) } } return handler(ctx, req) } } + +func withLogrusContext() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if userEmail, ok := ctx.Value(authenticatedUserEmailContextKey{}).(string); ok { + ctx_logrus.AddFields(ctx, logrus.Fields{ + logrusActorKey: userEmail, + }) + } + + return handler(ctx, req) + } +} diff --git a/internal/server/config.go b/internal/server/config.go index 0caa423a1..1e4490b24 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/odpf/guardian/internal/store" + "github.com/odpf/guardian/pkg/auth" "github.com/odpf/guardian/pkg/tracing" "github.com/odpf/guardian/plugins/notifiers" "github.com/odpf/salt/config" @@ -33,16 +34,28 @@ type Jobs struct { ExpiringAccessNotification JobConfig `mapstructure:"expiring_access_notification"` } +type DefaultAuth struct { + HeaderKey string `mapstructure:"header_key" default:"X-Auth-Email"` +} + +type Auth struct { + Provider string `mapstructure:"provider" default:"default"` + Default DefaultAuth `mapstructure:"default"` + OIDC auth.OIDCAuth `mapstructure:"oidc"` +} + type Config struct { - Port int `mapstructure:"port" default:"8080"` - EncryptionSecretKeyKey string `mapstructure:"encryption_secret_key"` - Notifier notifiers.Config `mapstructure:"notifier"` - LogLevel string `mapstructure:"log_level" default:"info"` - DB store.Config `mapstructure:"db"` - AuthenticatedUserHeaderKey string `mapstructure:"authenticated_user_header_key"` - AuditLogTraceIDHeaderKey string `mapstructure:"audit_log_trace_id_header_key" default:"X-Trace-Id"` - Jobs Jobs `mapstructure:"jobs"` - Telemetry tracing.Config `mapstructure:"telemetry"` + Port int `mapstructure:"port" default:"8080"` + EncryptionSecretKeyKey string `mapstructure:"encryption_secret_key"` + Notifier notifiers.Config `mapstructure:"notifier"` + LogLevel string `mapstructure:"log_level" default:"info"` + DB store.Config `mapstructure:"db"` + // Deprecated: use Auth.Default.HeaderKey instead note on the AuthenticatedUserHeaderKey + AuthenticatedUserHeaderKey string `mapstructure:"authenticated_user_header_key"` + AuditLogTraceIDHeaderKey string `mapstructure:"audit_log_trace_id_header_key" default:"X-Trace-Id"` + Jobs Jobs `mapstructure:"jobs"` + Telemetry tracing.Config `mapstructure:"telemetry"` + Auth Auth `mapstructure:"auth"` } func LoadConfig(configFile string) (Config, error) { @@ -56,5 +69,11 @@ func LoadConfig(configFile string) (Config, error) { } return Config{}, err } + + // keep for backward-compatibility + if cfg.AuthenticatedUserHeaderKey != "" { + cfg.Auth.Default.HeaderKey = cfg.AuthenticatedUserHeaderKey + } + return cfg, nil } diff --git a/internal/server/config.yaml b/internal/server/config.yaml index c7b78059b..f2250e925 100644 --- a/internal/server/config.yaml +++ b/internal/server/config.yaml @@ -11,7 +11,7 @@ PORT: 3000 ENCRYPTION_SECRET_KEY: -AUTHENTICATED_USER_HEADER_KEY: X-User-Email +AUTHENTICATED_USER_HEADER_KEY: X-Auth-Email LOG: LEVEL: info DB: @@ -42,4 +42,11 @@ TELEMETRY: OTLP: HEADERS: api-key: - ENDPOINT: "otlp.nr-data.net:4317" \ No newline at end of file + ENDPOINT: "otlp.nr-data.net:4317" +AUTH: + PROVIDER: default # can be "default" or "oidc" + DEFAULT: + HEADER_KEY: X-Auth-Email # AUTHENTICATED_USER_HEADER_KEY takes priority for backward-compatibility + OIDC: + AUDIENCE: "some-kind-of-audience.com" + ELIGIBLE_EMAIL_DOMAINS: "emaildomain1.com,emaildomain2.com" \ No newline at end of file diff --git a/internal/server/server.go b/internal/server/server.go index b37efe7a1..b9ff34241 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -20,6 +20,7 @@ import ( guardianv1beta1 "github.com/odpf/guardian/api/proto/odpf/guardian/v1beta1" "github.com/odpf/guardian/internal/store/postgres" "github.com/odpf/guardian/jobs" + "github.com/odpf/guardian/pkg/auth" "github.com/odpf/guardian/pkg/crypto" "github.com/odpf/guardian/pkg/scheduler" "github.com/odpf/guardian/pkg/tracing" @@ -30,6 +31,7 @@ import ( "github.com/sirupsen/logrus" "github.com/uptrace/opentelemetry-go-extra/otelgorm" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "google.golang.org/api/idtoken" "google.golang.org/grpc" "google.golang.org/protobuf/encoding/protojson" ) @@ -104,6 +106,12 @@ func RunServer(config *Config) error { // init grpc server logrusEntry := logrus.NewEntry(logrus.New()) // TODO: get logrus instance from `logger` var + + authInterceptor, err := getAuthInterceptor(config) + if err != nil { + return err + } + grpcServer := grpc.NewServer( grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( grpc_logrus.StreamServerInterceptor(logrusEntry), @@ -117,10 +125,17 @@ func RunServer(config *Config) error { }), ), grpc_logrus.UnaryServerInterceptor(logrusEntry), - withAuthenticatedUserEmail(config.AuthenticatedUserHeaderKey), + authInterceptor, + withLogrusContext(), otelgrpc.UnaryServerInterceptor(), )), ) + + authUserContextKey := map[string]interface{}{ + "default": authenticatedUserEmailContextKey{}, + "oidc": auth.OIDCEmailContextKey{}, + } + protoAdapter := handlerv1beta1.NewAdapter() guardianv1beta1.RegisterGuardianServiceServer(grpcServer, handlerv1beta1.NewGRPCServer( services.ResourceService, @@ -131,7 +146,7 @@ func RunServer(config *Config) error { services.ApprovalService, services.GrantService, protoAdapter, - config.AuthenticatedUserHeaderKey, + authUserContextKey[config.Auth.Provider], )) // init http proxy @@ -218,7 +233,7 @@ func makeHeaderMatcher(c *Config) func(key string) (string, bool) { return func(key string) (string, bool) { switch strings.ToLower(key) { case - strings.ToLower(c.AuthenticatedUserHeaderKey), + strings.ToLower(c.Auth.Default.HeaderKey), strings.ToLower(c.AuditLogTraceIDHeaderKey): return key, true default: @@ -266,3 +281,20 @@ func fetchDefaultJobScheduleMapping() map[JobType]string { ExpiringGrantNotification: "0 9 * * *", } } + +func getAuthInterceptor(config *Config) (grpc.UnaryServerInterceptor, error) { + // default fallback to user email on header + authInterceptor := withAuthenticatedUserEmail(config.Auth.Default.HeaderKey) + + if config.Auth.Provider == "oidc" { + idtokenValidator, err := idtoken.NewValidator(context.Background()) + if err != nil { + return nil, err + } + + bearerTokenValidator := auth.NewOIDCValidator(idtokenValidator, config.Auth.OIDC) + authInterceptor = bearerTokenValidator.WithOIDCValidator() + } + + return authInterceptor, nil +} diff --git a/internal/server/services.go b/internal/server/services.go index dc74b0df2..ef2007f9d 100644 --- a/internal/server/services.go +++ b/internal/server/services.go @@ -17,6 +17,7 @@ import ( "github.com/odpf/guardian/core/resource" "github.com/odpf/guardian/domain" "github.com/odpf/guardian/internal/store/postgres" + "github.com/odpf/guardian/pkg/auth" "github.com/odpf/guardian/plugins/identities" "github.com/odpf/guardian/plugins/notifiers" "github.com/odpf/guardian/plugins/providers/bigquery" @@ -66,6 +67,8 @@ func InitServices(deps ServiceDeps) (*Services, error) { auditRepository := audit_repos.NewPostgresRepository(sqldb) auditRepository.Init(context.TODO()) + actorExtractor := getActorExtractor(deps.Config) + auditLogger := audit.New( audit.WithRepository(auditRepository), audit.WithMetadataExtractor(func(ctx context.Context) map[string]interface{} { @@ -88,12 +91,7 @@ func InitServices(deps ServiceDeps) (*Services, error) { return md }), - audit.WithActorExtractor(func(ctx context.Context) (string, error) { - if actor, ok := ctx.Value(authenticatedUserEmailContextKey{}).(string); ok { - return actor, nil - } - return "", nil - }), + actorExtractor, ) activityRepository := postgres.NewActivityRepository(store.DB()) @@ -184,3 +182,19 @@ func InitServices(deps ServiceDeps) (*Services, error) { grantService, }, nil } + +func getActorExtractor(config *Config) audit.AuditOption { + var contextKey interface{} + + contextKey = authenticatedUserEmailContextKey{} + if config.Auth.Provider == "oidc" { + contextKey = auth.OIDCEmailContextKey{} + } + + return audit.WithActorExtractor(func(ctx context.Context) (string, error) { + if actor, ok := ctx.Value(contextKey).(string); ok { + return actor, nil + } + return "", nil + }) +} diff --git a/pkg/auth/mocks/OIDCValidator.go b/pkg/auth/mocks/OIDCValidator.go new file mode 100644 index 000000000..62bb30ba3 --- /dev/null +++ b/pkg/auth/mocks/OIDCValidator.go @@ -0,0 +1,53 @@ +// Code generated by mockery v2.16.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + idtoken "google.golang.org/api/idtoken" +) + +// OIDCValidator is an autogenerated mock type for the Validator type +type OIDCValidator struct { + mock.Mock +} + +// Validate provides a mock function with given fields: ctx, token, audience +func (_m *OIDCValidator) Validate(ctx context.Context, token string, audience string) (*idtoken.Payload, error) { + ret := _m.Called(ctx, token, audience) + + var r0 *idtoken.Payload + if rf, ok := ret.Get(0).(func(context.Context, string, string) *idtoken.Payload); ok { + r0 = rf(ctx, token, audience) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*idtoken.Payload) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, token, audience) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type mockConstructorTestingTNewOIDCValidator interface { + mock.TestingT + Cleanup(func()) +} + +// NewOIDCValidator creates a new instance of OIDCValidator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewOIDCValidator(t mockConstructorTestingTNewOIDCValidator) *OIDCValidator { + mock := &OIDCValidator{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go new file mode 100644 index 000000000..19da043dc --- /dev/null +++ b/pkg/auth/oidc.go @@ -0,0 +1,99 @@ +package auth + +import ( + "context" + "strings" + + "google.golang.org/api/idtoken" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +var InvalidAuthError = status.Errorf(codes.Unauthenticated, "invalid authentication credentials") + +type Validator interface { + Validate(ctx context.Context, token string, audience string) (*idtoken.Payload, error) +} + +type OIDCEmailContextKey struct{} + +type OIDCAuth struct { + Audience string `mapstructure:"audience"` + EligibleEmailDomains string `mapstructure:"eligible_email_domains"` +} + +type OIDCValidator struct { + validator Validator + audience string + validEmailDomains []string +} + +func NewOIDCValidator(validator Validator, config OIDCAuth) *OIDCValidator { + audience := config.Audience + + var validEmailDomains []string + if strings.TrimSpace(config.EligibleEmailDomains) != "" { + validEmailDomains = strings.Split(config.EligibleEmailDomains, ",") + } + + return &OIDCValidator{ + validator: validator, + audience: audience, + validEmailDomains: validEmailDomains, + } +} + +func (v *OIDCValidator) WithOIDCValidator() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, InvalidAuthError + } + + headerValue := md.Get("authorization") + if len(headerValue) == 0 || strings.TrimSpace(headerValue[0]) == "" { + return nil, InvalidAuthError + } + + bearerToken := strings.TrimSpace(strings.TrimPrefix(headerValue[0], "Bearer ")) + if len(bearerToken) == 0 { + return nil, InvalidAuthError + } + + payload, err := v.validator.Validate(ctx, bearerToken, v.audience) + if err != nil { + return nil, InvalidAuthError + } + + email := payload.Claims["email"].(string) + if err := v.validateEmailDomain(email); err != nil { + return nil, err + } + + ctx = context.WithValue(ctx, OIDCEmailContextKey{}, email) + + return handler(ctx, req) + } +} + +func (v *OIDCValidator) validateEmailDomain(email string) error { + // no valid email domains listed means that no email domain will be checked + if len(v.validEmailDomains) == 0 { + return nil + } + + emailDomainMatch := false + for _, validEmailDomain := range v.validEmailDomains { + if strings.HasSuffix(email, "@"+validEmailDomain) { + emailDomainMatch = true + break + } + } + + if !emailDomainMatch { + return InvalidAuthError + } + return nil +} diff --git a/pkg/auth/oidc_test.go b/pkg/auth/oidc_test.go new file mode 100644 index 000000000..273fe979d --- /dev/null +++ b/pkg/auth/oidc_test.go @@ -0,0 +1,155 @@ +package auth_test + +import ( + "context" + "errors" + "testing" + + "github.com/odpf/guardian/pkg/auth" + "github.com/odpf/guardian/pkg/auth/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "google.golang.org/api/idtoken" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +var authContextValues = map[string]string{ + "Authorization": "Bearer some-bearer-token-in-JWT", +} + +type InterceptorTestSuite struct { + suite.Suite +} + +func (s *InterceptorTestSuite) TestIdTokenValidator_WithBearerTokenValidator() { + emptyAuthContextValues := map[string]string{ + "Authorization": "Bearer ", + } + + testCases := []struct { + name string + params auth.OIDCAuth + ctx context.Context + mockFunc func(validator *mocks.OIDCValidator) + expectedErr error + }{ + { + name: "MD context value does not exist", + params: auth.OIDCAuth{}, + ctx: context.Background(), + mockFunc: func(validator *mocks.OIDCValidator) {}, + expectedErr: auth.InvalidAuthError, + }, + { + name: "empty authorization header", + params: auth.OIDCAuth{}, + ctx: metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{})), + mockFunc: func(validator *mocks.OIDCValidator) {}, + expectedErr: auth.InvalidAuthError, + }, + { + name: "empty bearer token on authorization header", + params: auth.OIDCAuth{}, + ctx: metadata.NewIncomingContext(context.Background(), metadata.New(emptyAuthContextValues)), + mockFunc: func(validator *mocks.OIDCValidator) {}, + expectedErr: auth.InvalidAuthError, + }, + { + name: "error while validating token", + params: auth.OIDCAuth{ + Audience: "google.com", + }, + ctx: metadata.NewIncomingContext(context.Background(), metadata.New(authContextValues)), + mockFunc: func(validator *mocks.OIDCValidator) { + validator.On("Validate", mock.Anything, mock.Anything, "google.com"). + Return(nil, errors.New("something happened")) + }, + expectedErr: auth.InvalidAuthError, + }, + { + name: "email domain does not match with eligible domains", + params: auth.OIDCAuth{ + Audience: "google.com", + EligibleEmailDomains: "example.com,something.org", + }, + ctx: metadata.NewIncomingContext(context.Background(), metadata.New(authContextValues)), + mockFunc: func(validator *mocks.OIDCValidator) { + + payload := &idtoken.Payload{ + Claims: map[string]interface{}{ + "email": "something@gmail.com", + }, + } + validator.On("Validate", mock.Anything, mock.Anything, "google.com"). + Return(payload, nil) + }, + expectedErr: auth.InvalidAuthError, + }, + { + name: "successful request with matching eligible email domains", + params: auth.OIDCAuth{ + Audience: "google.com", + EligibleEmailDomains: "example.com,something.org", + }, + ctx: metadata.NewIncomingContext(context.Background(), metadata.New(authContextValues)), + mockFunc: func(validator *mocks.OIDCValidator) { + payload := &idtoken.Payload{ + Claims: map[string]interface{}{ + "email": "something@example.com", + }, + } + validator.On("Validate", mock.Anything, mock.Anything, "google.com"). + Return(payload, nil) + }, + expectedErr: nil, + }, + { + name: "successful request with no eligible email domains configurations whatsoever", + params: auth.OIDCAuth{ + Audience: "google.com", + }, + ctx: metadata.NewIncomingContext(context.Background(), metadata.New(authContextValues)), + mockFunc: func(validator *mocks.OIDCValidator) { + payload := &idtoken.Payload{ + Claims: map[string]interface{}{ + "email": "something@example.com", + }, + } + validator.On("Validate", mock.Anything, mock.Anything, "google.com"). + Return(payload, nil) + }, + expectedErr: nil, + }, + } + + var req interface{} + + for _, tc := range testCases { + s.Run(tc.name, func() { + validator := new(mocks.OIDCValidator) + authValidator := auth.NewOIDCValidator(validator, tc.params) + interceptFunc := authValidator.WithOIDCValidator() + + tc.mockFunc(validator) + result, err := interceptFunc(tc.ctx, req, &grpc.UnaryServerInfo{}, s.unaryDummyHandler) + + assert.Nil(s.T(), result) + assert.Equal(s.T(), tc.expectedErr, err) + }) + } +} + +func (suite *InterceptorTestSuite) unaryDummyHandler(ctx context.Context, _ interface{}) (interface{}, error) { + expectedCtx := metadata.NewIncomingContext(context.Background(), metadata.New(authContextValues)) + expectedCtx = context.WithValue(expectedCtx, auth.OIDCEmailContextKey{}, "something@example.com") + + assert.Equal(suite.T(), expectedCtx, ctx, "final method handler doesn't have matching context") + + return nil, nil +} + +func TestOidcValidatorInterceptor(t *testing.T) { + suite.Run(t, new(InterceptorTestSuite)) +}