From 7343852a664c8841ee23ed8106fd89d0e0dcb64b Mon Sep 17 00:00:00 2001 From: sethsec-bf <46326948+sethsec-bf@users.noreply.github.com> Date: Thu, 21 Dec 2023 11:22:10 -0500 Subject: [PATCH] Updated endpoints to use the newly created sdk functions --- aws/endpoints.go | 439 ++++++++++++++++------------------------------- 1 file changed, 149 insertions(+), 290 deletions(-) diff --git a/aws/endpoints.go b/aws/endpoints.go index 37b8f1c..1fadc66 100644 --- a/aws/endpoints.go +++ b/aws/endpoints.go @@ -930,10 +930,6 @@ func (m *EndpointsModule) getAPIGatewayVIPsPerRegion(r string, wg *sync.WaitGrou // m.CommandCounter.Total++ m.CommandCounter.Pending-- m.CommandCounter.Executing++ - // "PaginationMarker" is a control variable used for output continuity, as AWS return the output in pages. - - var PaginationControl2 *string - var PaginationControl3 *string Items, err := sdk.CachedApiGatewayGetRestAPIs(m.APIGatewayClient, aws.ToString(m.Caller.Account), r) @@ -943,16 +939,20 @@ func (m *EndpointsModule) getAPIGatewayVIPsPerRegion(r string, wg *sync.WaitGrou return } - for { - GetDomainNames, err := m.APIGatewayClient.GetDomainNames( - context.TODO(), - &apigateway.GetDomainNamesInput{ - Position: PaginationControl2, - }, - func(o *apigateway.Options) { - o.Region = r - }, - ) + GetDomainNames, err := sdk.CachedApiGatewayGetDomainNames(m.APIGatewayClient, aws.ToString(m.Caller.Account), r) + + if err != nil { + if errors.As(err, &oe) { + m.Errors = append(m.Errors, (fmt.Sprintf(" Error: Region: %s, Service: %s, Operation: %s", r, oe.Service(), oe.Operation()))) + } + m.modLog.Error(err.Error()) + m.CommandCounter.Error++ + } + + for _, item := range GetDomainNames { + + domain := aws.ToString(item.DomainName) + GetBasePathMappings, err := sdk.CachedApiGatewayGetBasePathMappings(m.APIGatewayClient, aws.ToString(m.Caller.Account), r, item.DomainName) if err != nil { if errors.As(err, &oe) { @@ -963,82 +963,43 @@ func (m *EndpointsModule) getAPIGatewayVIPsPerRegion(r string, wg *sync.WaitGrou break } - for _, item := range GetDomainNames.Items { - - domain := aws.ToString(item.DomainName) - - for { - GetBasePathMappings, err := m.APIGatewayClient.GetBasePathMappings( - context.TODO(), - &apigateway.GetBasePathMappingsInput{ - DomainName: item.DomainName, - Position: PaginationControl3, - }, - func(o *apigateway.Options) { - o.Region = r - }, - ) - - if err != nil { - if errors.As(err, &oe) { - m.Errors = append(m.Errors, (fmt.Sprintf(" Error: Region: %s, Service: %s, Operation: %s", r, oe.Service(), oe.Operation()))) - } - m.modLog.Error(err.Error()) - m.CommandCounter.Error++ - break - } - - for _, mapping := range GetBasePathMappings.Items { - stage := aws.ToString(mapping.Stage) - basePath := aws.ToString(mapping.BasePath) - if basePath == "(none)" { - basePath = "" // Empty string since '/' is already prepended - } + for _, mapping := range GetBasePathMappings { + stage := aws.ToString(mapping.Stage) + basePath := aws.ToString(mapping.BasePath) + if basePath == "(none)" { + basePath = "" // Empty string since '/' is already prepended + } - for _, api := range Items { - if api.Id != nil && aws.ToString(api.Id) == aws.ToString(mapping.RestApiId) { - endpoints := m.getEndpointsPerAPIGateway(r, api) - for _, endpoint := range endpoints { - old := fmt.Sprintf("https://%s.execute-api.%s.amazonaws.com/%s/", aws.ToString(mapping.RestApiId), r, stage) - - if strings.HasPrefix(endpoint.Endpoint, old) { - var new string - if basePath == "" { - new = fmt.Sprintf("https://%s/", domain) - } else { - new = fmt.Sprintf("https://%s/%s/", domain, basePath) - } - endpoint.Endpoint = strings.Replace(endpoint.Endpoint, old, new, 1) - endpoint.Name = domain - dataReceiver <- endpoint - } + for _, api := range Items { + if api.Id != nil && aws.ToString(api.Id) == aws.ToString(mapping.RestApiId) { + endpoints := m.getEndpointsPerAPIGateway(r, api) + for _, endpoint := range endpoints { + old := fmt.Sprintf("https://%s.execute-api.%s.amazonaws.com/%s/", aws.ToString(mapping.RestApiId), r, stage) + + if strings.HasPrefix(endpoint.Endpoint, old) { + var new string + if basePath == "" { + new = fmt.Sprintf("https://%s/", domain) + } else { + new = fmt.Sprintf("https://%s/%s/", domain, basePath) } - break + endpoint.Endpoint = strings.Replace(endpoint.Endpoint, old, new, 1) + endpoint.Name = domain + dataReceiver <- endpoint } } - } - - if GetBasePathMappings.Position != nil { - PaginationControl3 = GetBasePathMappings.Position - } else { - PaginationControl3 = nil break } } } - if GetDomainNames.Position != nil { - PaginationControl2 = GetDomainNames.Position - } else { - PaginationControl2 = nil - break - } + } + } func (m *EndpointsModule) getEndpointsPerAPIGateway(r string, api apigatewayTypes.RestApi) []Endpoint { var endpoints []Endpoint - var PaginationControl2 *string awsService := "APIGateway" var public string @@ -1056,15 +1017,7 @@ func (m *EndpointsModule) getEndpointsPerAPIGateway(r string, api apigatewayType public = "True" } - GetStages, err := m.APIGatewayClient.GetStages( - context.TODO(), - &apigateway.GetStagesInput{ - RestApiId: &id, - }, - func(o *apigateway.Options) { - o.Region = r - }, - ) + GetStages, err := sdk.CachedApiGatewayGetStages(m.APIGatewayClient, aws.ToString(m.Caller.Account), r, id) if err != nil { if errors.As(err, &oe) { @@ -1074,54 +1027,38 @@ func (m *EndpointsModule) getEndpointsPerAPIGateway(r string, api apigatewayType m.CommandCounter.Error++ } - for { - GetResources, err := m.APIGatewayClient.GetResources( - context.TODO(), - &apigateway.GetResourcesInput{ - RestApiId: &id, - Position: PaginationControl2, - }, - func(o *apigateway.Options) { - o.Region = r - }, - ) + GetResources, err := sdk.CachedApiGatewayGetResources(m.APIGatewayClient, aws.ToString(m.Caller.Account), r, id) - if err != nil { - if errors.As(err, &oe) { - m.Errors = append(m.Errors, (fmt.Sprintf(" Error: Region: %s, Service: %s, Operation: %s", r, oe.Service(), oe.Operation()))) - } - m.modLog.Error(err.Error()) - m.CommandCounter.Error++ - break + if err != nil { + if errors.As(err, &oe) { + m.Errors = append(m.Errors, (fmt.Sprintf(" Error: Region: %s, Service: %s, Operation: %s", r, oe.Service(), oe.Operation()))) } + m.modLog.Error(err.Error()) + m.CommandCounter.Error++ - for _, stage := range GetStages.Item { - stageName := aws.ToString(stage.StageName) - for _, resource := range GetResources.Items { - if len(resource.ResourceMethods) != 0 { - path := aws.ToString(resource.Path) - - endpoint := fmt.Sprintf("%s/%s%s", raw_endpoint, stageName, path) - - endpoints = append(endpoints, Endpoint{ - AWSService: awsService, - Region: r, - Name: name, - Endpoint: endpoint, - Port: port, - Protocol: protocol, - Public: public, - }) - } + } + + for _, stage := range GetStages.Item { + stageName := aws.ToString(stage.StageName) + for _, resource := range GetResources { + if len(resource.ResourceMethods) != 0 { + path := aws.ToString(resource.Path) + + endpoint := fmt.Sprintf("%s/%s%s", raw_endpoint, stageName, path) + + endpoints = append(endpoints, Endpoint{ + AWSService: awsService, + Region: r, + Name: name, + Endpoint: endpoint, + Port: port, + Protocol: protocol, + Public: public, + }) } } - if GetResources.Position != nil { - PaginationControl2 = GetResources.Position - } else { - PaginationControl2 = nil - break - } } + return endpoints } @@ -1170,10 +1107,6 @@ func (m *EndpointsModule) getAPIGatewayv2VIPsPerRegion(r string, wg *sync.WaitGr // m.CommandCounter.Total++ m.CommandCounter.Pending-- m.CommandCounter.Executing++ - // "PaginationMarker" is a control variable used for output continuity, as AWS return the output in pages. - - var PaginationControl2 *string - var PaginationControl3 *string Items, err := sdk.CachedAPIGatewayv2GetAPIs(m.APIGatewayv2Client, aws.ToString(m.Caller.Account), r) @@ -1183,16 +1116,20 @@ func (m *EndpointsModule) getAPIGatewayv2VIPsPerRegion(r string, wg *sync.WaitGr return } - for { - GetDomainNames, err := m.APIGatewayv2Client.GetDomainNames( - context.TODO(), - &apigatewayv2.GetDomainNamesInput{ - NextToken: PaginationControl2, - }, - func(o *apigatewayv2.Options) { - o.Region = r - }, - ) + GetDomainNames, err := sdk.CachedAPIGatewayv2GetDomainNames(m.APIGatewayv2Client, aws.ToString(m.Caller.Account), r) + + if err != nil { + if errors.As(err, &oe) { + m.Errors = append(m.Errors, (fmt.Sprintf(" Error: Region: %s, Service: %s, Operation: %s", r, oe.Service(), oe.Operation()))) + } + m.modLog.Error(err.Error()) + m.CommandCounter.Error++ + } + + for _, item := range GetDomainNames { + + domain := aws.ToString(item.DomainName) + GetApiMappings, err := sdk.CachedAPIGatewayv2GetApiMappings(m.APIGatewayv2Client, aws.ToString(m.Caller.Account), r, domain) if err != nil { if errors.As(err, &oe) { @@ -1203,87 +1140,46 @@ func (m *EndpointsModule) getAPIGatewayv2VIPsPerRegion(r string, wg *sync.WaitGr break } - for _, item := range GetDomainNames.Items { - - domain := aws.ToString(item.DomainName) - - for { - GetApiMappings, err := m.APIGatewayv2Client.GetApiMappings( - context.TODO(), - &apigatewayv2.GetApiMappingsInput{ - DomainName: item.DomainName, - NextToken: PaginationControl3, - }, - func(o *apigatewayv2.Options) { - o.Region = r - }, - ) - - if err != nil { - if errors.As(err, &oe) { - m.Errors = append(m.Errors, (fmt.Sprintf(" Error: Region: %s, Service: %s, Operation: %s", r, oe.Service(), oe.Operation()))) - } - m.modLog.Error(err.Error()) - m.CommandCounter.Error++ - break - } - - for _, mapping := range GetApiMappings.Items { - stage := aws.ToString(mapping.Stage) - if stage == "$default" { - stage = "" - } - path := aws.ToString(mapping.ApiMappingKey) - - for _, api := range Items { - if api.ApiId != nil && aws.ToString(api.ApiId) == aws.ToString(mapping.ApiId) { - endpoints := m.getEndpointsPerAPIGatewayv2(r, api) - for _, endpoint := range endpoints { - var old string - if stage == "" { - old = fmt.Sprintf("https://%s.execute-api.%s.amazonaws.com/", aws.ToString(mapping.ApiId), r) - } else { - old = fmt.Sprintf("https://%s.execute-api.%s.amazonaws.com/%s/", aws.ToString(mapping.ApiId), r, stage) - } - if strings.HasPrefix(endpoint.Endpoint, old) { - var new string - if path == "" { - new = fmt.Sprintf("https://%s/", domain) - } else { - new = fmt.Sprintf("https://%s/%s/", domain, path) - } - endpoint.Endpoint = strings.Replace(endpoint.Endpoint, old, new, 1) - endpoint.Name = domain - dataReceiver <- endpoint - } + for _, mapping := range GetApiMappings { + stage := aws.ToString(mapping.Stage) + if stage == "$default" { + stage = "" + } + path := aws.ToString(mapping.ApiMappingKey) + + for _, api := range Items { + if api.ApiId != nil && aws.ToString(api.ApiId) == aws.ToString(mapping.ApiId) { + endpoints := m.getEndpointsPerAPIGatewayv2(r, api) + for _, endpoint := range endpoints { + var old string + if stage == "" { + old = fmt.Sprintf("https://%s.execute-api.%s.amazonaws.com/", aws.ToString(mapping.ApiId), r) + } else { + old = fmt.Sprintf("https://%s.execute-api.%s.amazonaws.com/%s/", aws.ToString(mapping.ApiId), r, stage) + } + if strings.HasPrefix(endpoint.Endpoint, old) { + var new string + if path == "" { + new = fmt.Sprintf("https://%s/", domain) + } else { + new = fmt.Sprintf("https://%s/%s/", domain, path) } - break + endpoint.Endpoint = strings.Replace(endpoint.Endpoint, old, new, 1) + endpoint.Name = domain + dataReceiver <- endpoint } } - } - - if GetApiMappings.NextToken != nil { - PaginationControl3 = GetApiMappings.NextToken - } else { - PaginationControl3 = nil break } } } - if GetDomainNames.NextToken != nil { - PaginationControl2 = GetDomainNames.NextToken - } else { - PaginationControl2 = nil - break - } + } + } func (m *EndpointsModule) getEndpointsPerAPIGatewayv2(r string, api apigatewayV2Types.Api) []Endpoint { var endpoints []Endpoint - - var PaginationControl2 *string - var PaginationControl3 *string awsService := "APIGatewayv2" var public string @@ -1295,98 +1191,61 @@ func (m *EndpointsModule) getEndpointsPerAPIGatewayv2(r string, api apigatewayV2 protocol := "https" var stages []string - for { - GetStages, err := m.APIGatewayv2Client.GetStages( - context.TODO(), - &apigatewayv2.GetStagesInput{ - ApiId: &id, - NextToken: PaginationControl2, - }, - func(o *apigatewayv2.Options) { - o.Region = r - }, - ) + GetStages, err := sdk.CachedAPIGatewayv2GetStages(m.APIGatewayv2Client, aws.ToString(m.Caller.Account), r, id) - if err != nil { - if errors.As(err, &oe) { - m.Errors = append(m.Errors, (fmt.Sprintf(" Error: Region: %s, Service: %s, Operation: %s", r, oe.Service(), oe.Operation()))) - } - m.modLog.Error(err.Error()) - m.CommandCounter.Error++ - break - } - - for _, stage := range GetStages.Items { - s := aws.ToString(stage.StageName) - if s == "$default" { - s = "" - } - stages = append(stages, s) + if err != nil { + if errors.As(err, &oe) { + m.Errors = append(m.Errors, (fmt.Sprintf(" Error: Region: %s, Service: %s, Operation: %s", r, oe.Service(), oe.Operation()))) } + m.modLog.Error(err.Error()) + m.CommandCounter.Error++ + } - if GetStages.NextToken != nil { - PaginationControl2 = GetStages.NextToken - } else { - PaginationControl2 = nil - break + for _, stage := range GetStages { + s := aws.ToString(stage.StageName) + if s == "$default" { + s = "" } + stages = append(stages, s) } - for { - GetRoutes, err := m.APIGatewayv2Client.GetRoutes( - context.TODO(), - &apigatewayv2.GetRoutesInput{ - ApiId: &id, - NextToken: PaginationControl3, - }, - func(o *apigatewayv2.Options) { - o.Region = r - }, - ) + GetRoutes, err := sdk.CachedAPIGatewayv2GetRoutes(m.APIGatewayv2Client, aws.ToString(m.Caller.Account), r, id) - if err != nil { - if errors.As(err, &oe) { - m.Errors = append(m.Errors, (fmt.Sprintf(" Error: Region: %s, Service: %s, Operation: %s", r, oe.Service(), oe.Operation()))) - } - m.modLog.Error(err.Error()) - m.CommandCounter.Error++ - continue + if err != nil { + if errors.As(err, &oe) { + m.Errors = append(m.Errors, (fmt.Sprintf(" Error: Region: %s, Service: %s, Operation: %s", r, oe.Service(), oe.Operation()))) } + m.modLog.Error(err.Error()) + m.CommandCounter.Error++ + } - for _, stage := range stages { - for _, route := range GetRoutes.Items { - routeKey := route.RouteKey - var path string - if len(strings.Fields(*routeKey)) == 2 { - path = strings.Fields(*routeKey)[1] - } - var endpoint string - if stage == "" { - endpoint = fmt.Sprintf("%s%s", raw_endpoint, path) - } else { - endpoint = fmt.Sprintf("%s/%s%s", raw_endpoint, stage, path) - } - public = "True" - - endpoints = append(endpoints, Endpoint{ - AWSService: awsService, - Region: r, - Name: name, - Endpoint: endpoint, - Port: port, - Protocol: protocol, - Public: public, - }) + for _, stage := range stages { + for _, route := range GetRoutes { + routeKey := route.RouteKey + var path string + if len(strings.Fields(*routeKey)) == 2 { + path = strings.Fields(*routeKey)[1] } - } - if GetRoutes.NextToken != nil { - PaginationControl3 = GetRoutes.NextToken - } else { - PaginationControl3 = nil - break - } + var endpoint string + if stage == "" { + endpoint = fmt.Sprintf("%s%s", raw_endpoint, path) + } else { + endpoint = fmt.Sprintf("%s/%s%s", raw_endpoint, stage, path) + } + public = "True" + endpoints = append(endpoints, Endpoint{ + AWSService: awsService, + Region: r, + Name: name, + Endpoint: endpoint, + Port: port, + Protocol: protocol, + Public: public, + }) + } } + return endpoints }