Skip to content

Commit

Permalink
Merge pull request hashicorp#34087 from hashicorp/td-framework-goodies
Browse files Browse the repository at this point in the history
Additional Framework helpers
  • Loading branch information
ewbankkit authored Oct 25, 2023
2 parents 1201526 + ae73c96 commit 0c6b499
Show file tree
Hide file tree
Showing 18 changed files with 192 additions and 38 deletions.
3 changes: 2 additions & 1 deletion internal/framework/flex/bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-plugin-framework/types/basetypes"
"github.com/hashicorp/terraform-provider-aws/internal/errs/fwdiag"
)

// BoolFromFramework converts a Framework Bool value to a bool pointer.
// A null Bool is converted to a nil bool pointer.
func BoolFromFramework(ctx context.Context, v types.Bool) *bool {
func BoolFromFramework(ctx context.Context, v basetypes.BoolValuable) *bool {
var output *bool

panicOnError(Expand(ctx, v, &output))
Expand Down
3 changes: 2 additions & 1 deletion internal/framework/flex/int.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-plugin-framework/types/basetypes"
)

// Int64FromFramework converts a Framework Int64 value to an int64 pointer.
// A null Int64 is converted to a nil int64 pointer.
func Int64FromFramework(ctx context.Context, v types.Int64) *int64 {
func Int64FromFramework(ctx context.Context, v basetypes.Int64Valuable) *int64 {
var output *int64

panicOnError(Expand(ctx, v, &output))
Expand Down
5 changes: 3 additions & 2 deletions internal/framework/flex/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/hashicorp/terraform-plugin-framework/attr"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-plugin-framework/types/basetypes"
"github.com/hashicorp/terraform-provider-aws/internal/errs/fwdiag"
"github.com/hashicorp/terraform-provider-aws/internal/slices"
)

func ExpandFrameworkStringList(ctx context.Context, v types.List) []*string {
func ExpandFrameworkStringList(ctx context.Context, v basetypes.ListValuable) []*string {
var output []*string

panicOnError(Expand(ctx, v, &output))

return output
}

func ExpandFrameworkStringValueList(ctx context.Context, v types.List) []string {
func ExpandFrameworkStringValueList(ctx context.Context, v basetypes.ListValuable) []string {
var output []string

panicOnError(Expand(ctx, v, &output))
Expand Down
5 changes: 3 additions & 2 deletions internal/framework/flex/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@ import (

"github.com/hashicorp/terraform-plugin-framework/attr"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-plugin-framework/types/basetypes"
)

func ExpandFrameworkStringMap(ctx context.Context, v types.Map) map[string]*string {
func ExpandFrameworkStringMap(ctx context.Context, v basetypes.MapValuable) map[string]*string {
var output map[string]*string

panicOnError(Expand(ctx, v, &output))

return output
}

func ExpandFrameworkStringValueMap(ctx context.Context, v types.Map) map[string]string {
func ExpandFrameworkStringValueMap(ctx context.Context, v basetypes.MapValuable) map[string]string {
var output map[string]string

panicOnError(Expand(ctx, v, &output))
Expand Down
5 changes: 3 additions & 2 deletions internal/framework/flex/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/hashicorp/terraform-plugin-framework/attr"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-plugin-framework/types/basetypes"
)

func ExpandFrameworkStringSet(ctx context.Context, v types.Set) []*string {
func ExpandFrameworkStringSet(ctx context.Context, v basetypes.SetValuable) []*string {
var output []*string

panicOnError(Expand(ctx, v, &output))

return output
}

func ExpandFrameworkStringValueSet(ctx context.Context, v types.Set) Set[string] {
func ExpandFrameworkStringValueSet(ctx context.Context, v basetypes.SetValuable) Set[string] {
var output []string

panicOnError(Expand(ctx, v, &output))
Expand Down
22 changes: 11 additions & 11 deletions internal/framework/flex/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-plugin-framework/types/basetypes"
fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types"
)

// StringFromFramework converts a Framework String value to a string pointer.
// A null String is converted to a nil string pointer.
func StringFromFramework(ctx context.Context, v types.String) *string {
func StringFromFramework(ctx context.Context, v basetypes.StringValuable) *string {
var output *string

panicOnError(Expand(ctx, v, &output))
Expand All @@ -24,7 +25,7 @@ func StringFromFramework(ctx context.Context, v types.String) *string {

// StringFromFramework converts a single Framework String value to a string pointer slice.
// A null String is converted to a nil slice.
func StringSliceFromFramework(ctx context.Context, v types.String) []*string {
func StringSliceFromFramework(ctx context.Context, v basetypes.StringValuable) []*string {
if v.IsNull() || v.IsUnknown() {
return nil
}
Expand Down Expand Up @@ -68,18 +69,17 @@ func StringToFrameworkLegacy(_ context.Context, v *string) types.String {
return types.StringValue(aws.ToString(v))
}

func ARNStringFromFramework(ctx context.Context, v fwtypes.ARN) *string {
var output *string

panicOnError(Expand(ctx, v, &output))

return output
}

// StringToFrameworkARN converts a string pointer to a Framework custom ARN value.
// A nil string pointer is converted to a null ARN.
// If diags is nil, any errors cause a panic.
func StringToFrameworkARN(ctx context.Context, v *string, diags *diag.Diagnostics) fwtypes.ARN {
var output fwtypes.ARN

diags.Append(Flatten(ctx, v, &output)...)
if diags == nil {
panicOnError(Flatten(ctx, v, &output))
} else {
diags.Append(Flatten(ctx, v, &output)...)
}

return output
}
Expand Down
2 changes: 1 addition & 1 deletion internal/framework/flex/string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func TestARNStringFromFramework(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

got := flex.ARNStringFromFramework(context.Background(), test.input)
got := flex.StringFromFramework(context.Background(), test.input)

if diff := cmp.Diff(got, test.expected); diff != "" {
t.Errorf("unexpected diff (+wanted, -got): %s", diff)
Expand Down
3 changes: 2 additions & 1 deletion internal/framework/types/arn.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ const (
)

var (
_ xattr.TypeWithValidate = ARNType
_ xattr.TypeWithValidate = ARNType
_ basetypes.StringValuable = ARN{}
)

func (t arnType) TerraformType(_ context.Context) tftypes.Type {
Expand Down
3 changes: 2 additions & 1 deletion internal/framework/types/cidr_block.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ const (
)

var (
_ xattr.TypeWithValidate = CIDRBlockType
_ xattr.TypeWithValidate = CIDRBlockType
_ basetypes.StringValuable = CIDRBlock{}
)

func (t cidrBlockType) TerraformType(_ context.Context) tftypes.Type {
Expand Down
3 changes: 2 additions & 1 deletion internal/framework/types/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ const (
)

var (
_ xattr.TypeWithValidate = DurationType
_ xattr.TypeWithValidate = DurationType
_ basetypes.StringValuable = Duration{}
)

func (d durationType) TerraformType(_ context.Context) tftypes.Type {
Expand Down
3 changes: 2 additions & 1 deletion internal/framework/types/regexp.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ var (
)

var (
_ xattr.TypeWithValidate = RegexpType
_ xattr.TypeWithValidate = RegexpType
_ basetypes.StringValuable = Regexp{}
)

func (t regexpType) TerraformType(_ context.Context) tftypes.Type {
Expand Down
52 changes: 52 additions & 0 deletions internal/framework/validators/aws_account_id.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package validators

import (
"context"

"github.com/YakDriver/regexache"
"github.com/hashicorp/terraform-plugin-framework-validators/helpers/validatordiag"
"github.com/hashicorp/terraform-plugin-framework/schema/validator"
)

// awsAccountIDValidator validates that a string Attribute's value is a valid AWS account ID.
type awsAccountIDValidator struct{}

// Description describes the validation in plain text formatting.
func (validator awsAccountIDValidator) Description(_ context.Context) string {
return "value must be a valid AWS account ID"
}

// MarkdownDescription describes the validation in Markdown formatting.
func (validator awsAccountIDValidator) MarkdownDescription(ctx context.Context) string {
return validator.Description(ctx)
}

// ValidateString performs the validation.
func (validator awsAccountIDValidator) ValidateString(ctx context.Context, request validator.StringRequest, response *validator.StringResponse) {
if request.ConfigValue.IsNull() || request.ConfigValue.IsUnknown() {
return
}

// https://docs.aws.amazon.com/accounts/latest/reference/manage-acct-identifiers.html.
if !regexache.MustCompile(`^\d{12}$`).MatchString(request.ConfigValue.ValueString()) {
response.Diagnostics.Append(validatordiag.InvalidAttributeValueDiagnostic(
request.Path,
validator.Description(ctx),
request.ConfigValue.ValueString(),
))
return
}
}

// AWSAccountID returns a string validator which ensures that any configured
// attribute value:
//
// - Is a string, which represents a valid AWS account ID.
//
// Null (unconfigured) and unknown (known after apply) values are skipped.
func AWSAccountID() validator.String { // nosemgrep:ci.aws-in-func-name
return awsAccountIDValidator{}
}
87 changes: 87 additions & 0 deletions internal/framework/validators/aws_account_id_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package validators_test

import (
"context"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/path"
"github.com/hashicorp/terraform-plugin-framework/schema/validator"
"github.com/hashicorp/terraform-plugin-framework/types"
fwvalidators "github.com/hashicorp/terraform-provider-aws/internal/framework/validators"
)

func TestAWSAccountIDValidator(t *testing.T) { // nosemgrep:ci.aws-in-func-name
t.Parallel()

type testCase struct {
val types.String
expectedDiagnostics diag.Diagnostics
}
tests := map[string]testCase{
"unknown String": {
val: types.StringUnknown(),
},
"null String": {
val: types.StringNull(),
},
"invalid String": {
val: types.StringValue("test-value"),
expectedDiagnostics: diag.Diagnostics{
diag.NewAttributeErrorDiagnostic(
path.Root("test"),
"Invalid Attribute Value",
`Attribute test value must be a valid AWS account ID, got: test-value`,
),
},
},
"valid AWS account ID": {
val: types.StringValue("123456789012"),
},
"too long AWS account ID": {
val: types.StringValue("1234567890123"),
expectedDiagnostics: diag.Diagnostics{
diag.NewAttributeErrorDiagnostic(
path.Root("test"),
"Invalid Attribute Value",
`Attribute test value must be a valid AWS account ID, got: 1234567890123`,
),
},
},
"too short AWS account ID": {
val: types.StringValue("12345678901"),
expectedDiagnostics: diag.Diagnostics{
diag.NewAttributeErrorDiagnostic(
path.Root("test"),
"Invalid Attribute Value",
`Attribute test value must be a valid AWS account ID, got: 12345678901`,
),
},
},
}

for name, test := range tests {
name, test := name, test
t.Run(name, func(t *testing.T) {
t.Parallel()

ctx := context.Background()

request := validator.StringRequest{
Path: path.Root("test"),
PathExpression: path.MatchRoot("test"),
ConfigValue: test.val,
}
response := validator.StringResponse{}
fwvalidators.AWSAccountID().ValidateString(ctx, request, &response)

if diff := cmp.Diff(response.Diagnostics, test.expectedDiagnostics); diff != "" {
t.Errorf("unexpected diagnostics difference: %s", diff)
}
})
}
}
6 changes: 3 additions & 3 deletions internal/service/batch/job_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (r *resourceJobQueue) Create(ctx context.Context, request resource.CreateRe
}

if !data.SchedulingPolicyARN.IsNull() {
input.SchedulingPolicyArn = flex.ARNStringFromFramework(ctx, data.SchedulingPolicyARN)
input.SchedulingPolicyArn = flex.StringFromFramework(ctx, data.SchedulingPolicyARN)
}

output, err := conn.CreateJobQueueWithContext(ctx, &input)
Expand Down Expand Up @@ -229,13 +229,13 @@ func (r *resourceJobQueue) Update(ctx context.Context, request resource.UpdateRe
}

if !state.SchedulingPolicyARN.IsNull() {
input.SchedulingPolicyArn = flex.ARNStringFromFramework(ctx, state.SchedulingPolicyARN)
input.SchedulingPolicyArn = flex.StringFromFramework(ctx, state.SchedulingPolicyARN)
update = true
}

if !plan.SchedulingPolicyARN.Equal(state.SchedulingPolicyARN) {
if !plan.SchedulingPolicyARN.IsNull() || !plan.SchedulingPolicyARN.IsUnknown() {
input.SchedulingPolicyArn = flex.ARNStringFromFramework(ctx, plan.SchedulingPolicyARN)
input.SchedulingPolicyArn = flex.StringFromFramework(ctx, plan.SchedulingPolicyARN)

update = true
} else {
Expand Down
4 changes: 2 additions & 2 deletions internal/service/cognitoidp/user_pool_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -706,10 +706,10 @@ func (ac *analyticsConfiguration) expand(ctx context.Context) *cognitoidentitypr
return nil
}
result := &cognitoidentityprovider.AnalyticsConfigurationType{
ApplicationArn: flex.ARNStringFromFramework(ctx, ac.ApplicationARN),
ApplicationArn: flex.StringFromFramework(ctx, ac.ApplicationARN),
ApplicationId: flex.StringFromFramework(ctx, ac.ApplicationID),
ExternalId: flex.StringFromFramework(ctx, ac.ExternalID),
RoleArn: flex.ARNStringFromFramework(ctx, ac.RoleARN),
RoleArn: flex.StringFromFramework(ctx, ac.RoleARN),
UserDataShared: flex.BoolFromFramework(ctx, ac.UserDataShared),
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package globalaccelerator
import (
"context"

"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/globalaccelerator"
"github.com/hashicorp/terraform-plugin-framework/attr"
Expand Down Expand Up @@ -150,11 +149,7 @@ func (d *dataSourceAccelerator) Read(ctx context.Context, request datasource.Rea

accelerator := results[0]
acceleratorARN := aws.StringValue(accelerator.AcceleratorArn)
if v, err := arn.Parse(acceleratorARN); err != nil {
response.Diagnostics.AddError("parsing ARN", err.Error())
} else {
data.ARN = fwtypes.ARNValue(v)
}
data.ARN = flex.StringToFrameworkARN(ctx, accelerator.AcceleratorArn, nil)
data.DnsName = flex.StringToFrameworkLegacy(ctx, accelerator.DnsName)
data.DualStackDNSName = flex.StringToFrameworkLegacy(ctx, accelerator.DualStackDnsName)
data.Enabled = flex.BoolToFrameworkLegacy(ctx, accelerator.Enabled)
Expand Down
Loading

0 comments on commit 0c6b499

Please sign in to comment.