diff --git a/cedar.go b/cedar.go index 4c72106e..6340a9df 100644 --- a/cedar.go +++ b/cedar.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/parser" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -94,13 +95,13 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { // An Entities is a collection of all the Entities that are needed to evaluate // authorization requests. The key is an EntityUID which uniquely identifies // the Entity (it must be the same as the UID within the Entity itself.) -type Entities map[EntityUID]Entity +type Entities map[types.EntityUID]Entity // An Entity defines the parents and attributes for an EntityUID. type Entity struct { - UID EntityUID `json:"uid"` - Parents []EntityUID `json:"parents,omitempty"` - Attributes Record `json:"attrs"` + UID types.EntityUID `json:"uid"` + Parents []types.EntityUID `json:"parents,omitempty"` + Attributes types.Record `json:"attrs"` } func (e Entities) MarshalJSON() ([]byte, error) { @@ -188,10 +189,10 @@ type Reason struct { // A Request is the Principal, Action, Resource, and Context portion of an // authorization request. type Request struct { - Principal EntityUID `json:"principal"` - Action EntityUID `json:"action"` - Resource EntityUID `json:"resource"` - Context Record `json:"context"` + Principal types.EntityUID `json:"principal"` + Action types.EntityUID `json:"action"` + Resource types.EntityUID `json:"resource"` + Context types.Record `json:"context"` } // IsAuthorized uses the combination of the PolicySet and Entities to determine @@ -220,7 +221,7 @@ func (p PolicySet) IsAuthorized(entities Entities, req Request) (Decision, Diagn diag.Errors = append(diag.Errors, Error{Policy: n, Position: po.Position, Message: err.Error()}) continue } - vb, err := valueToBool(v) + vb, err := types.ValueToBool(v) if err != nil { // should never happen, maybe remove this case diag.Errors = append(diag.Errors, Error{Policy: n, Position: po.Position, Message: err.Error()}) diff --git a/cedar_test.go b/cedar_test.go index 1f34b863..774e4233 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -1,27 +1,31 @@ package cedar import ( + "encoding/json" "net/netip" "testing" + + "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/types" ) func TestEntityIsZero(t *testing.T) { t.Parallel() tests := []struct { name string - uid EntityUID + uid types.EntityUID want bool }{ - {"empty", EntityUID{}, true}, - {"empty-type", NewEntityUID("one", ""), false}, - {"empty-id", NewEntityUID("", "one"), false}, - {"not-empty", NewEntityUID("one", "two"), false}, + {"empty", types.EntityUID{}, true}, + {"empty-type", types.NewEntityUID("one", ""), false}, + {"empty-id", types.NewEntityUID("", "one"), false}, + {"not-empty", types.NewEntityUID("one", "two"), false}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - testutilEquals(t, tt.uid.IsZero(), tt.want) + testutil.Equals(t, tt.uid.IsZero(), tt.want) }) } } @@ -31,18 +35,18 @@ func TestNewPolicySet(t *testing.T) { t.Run("err-in-tokenize", func(t *testing.T) { t.Parallel() _, err := NewPolicySet("policy.cedar", []byte(`"`)) - testutilError(t, err) + testutil.Error(t, err) }) t.Run("err-in-parse", func(t *testing.T) { t.Parallel() _, err := NewPolicySet("policy.cedar", []byte(`err`)) - testutilError(t, err) + testutil.Error(t, err) }) t.Run("annotations", func(t *testing.T) { t.Parallel() ps, err := NewPolicySet("policy.cedar", []byte(`@key("value") permit (principal, action, resource);`)) - testutilOK(t, err) - testutilEquals(t, ps[0].Annotations, Annotations{"key": "value"}) + testutil.OK(t, err) + testutil.Equals(t, ps[0].Annotations, Annotations{"key": "value"}) }) } @@ -52,8 +56,8 @@ func TestIsAuthorized(t *testing.T) { Name string Policy string Entities Entities - Principal, Action, Resource EntityUID - Context Record + Principal, Action, Resource types.EntityUID + Context types.Record Want Decision DiagErr int }{ @@ -61,10 +65,10 @@ func TestIsAuthorized(t *testing.T) { Name: "simple-permit", Policy: `permit(principal,action,resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -72,10 +76,10 @@ func TestIsAuthorized(t *testing.T) { Name: "simple-forbid", Policy: `forbid(principal,action,resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 0, }, @@ -83,10 +87,10 @@ func TestIsAuthorized(t *testing.T) { Name: "no-permit", Policy: `permit(principal,action,resource in asdf::"1234");`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 0, }, @@ -94,10 +98,10 @@ func TestIsAuthorized(t *testing.T) { Name: "error-in-policy", Policy: `permit(principal,action,resource) when { resource in "foo" };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -107,10 +111,10 @@ func TestIsAuthorized(t *testing.T) { permit(principal,action,resource); `, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 1, }, @@ -118,10 +122,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-requires-context-success", Policy: `permit(principal,action,resource) when { context.x == 42 };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{"x": Long(42)}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{"x": types.Long(42)}, Want: true, DiagErr: 0, }, @@ -129,10 +133,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-requires-context-fail", Policy: `permit(principal,action,resource) when { context.x == 42 };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{"x": Long(43)}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{"x": types.Long(43)}, Want: false, DiagErr: 0, }, @@ -141,14 +145,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { principal.x == 42 };`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"coder", "cuzco"}, - Attributes: Record{"x": Long(42)}, + UID: types.EntityUID{"coder", "cuzco"}, + Attributes: types.Record{"x": types.Long(42)}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -157,14 +161,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { principal.x == 42 };`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"coder", "cuzco"}, - Attributes: Record{"x": Long(43)}, + UID: types.EntityUID{"coder", "cuzco"}, + Attributes: types.Record{"x": types.Long(43)}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 0, }, @@ -173,14 +177,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { principal in parent::"bob" };`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"coder", "cuzco"}, - Parents: []EntityUID{{"parent", "bob"}}, + UID: types.EntityUID{"coder", "cuzco"}, + Parents: []types.EntityUID{{"parent", "bob"}}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -188,10 +192,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-principal-equals", Policy: `permit(principal == coder::"cuzco",action,resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -200,14 +204,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal in team::"osiris",action,resource);`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"coder", "cuzco"}, - Parents: []EntityUID{{"team", "osiris"}}, + UID: types.EntityUID{"coder", "cuzco"}, + Parents: []types.EntityUID{{"team", "osiris"}}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -215,10 +219,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-action-equals", Policy: `permit(principal,action == table::"drop",resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -227,14 +231,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action in scary::"stuff",resource);`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"table", "drop"}, - Parents: []EntityUID{{"scary", "stuff"}}, + UID: types.EntityUID{"table", "drop"}, + Parents: []types.EntityUID{{"scary", "stuff"}}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -243,14 +247,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action in [scary::"stuff"],resource);`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"table", "drop"}, - Parents: []EntityUID{{"scary", "stuff"}}, + UID: types.EntityUID{"table", "drop"}, + Parents: []types.EntityUID{{"scary", "stuff"}}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -258,10 +262,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-resource-equals", Policy: `permit(principal,action,resource == table::"whatever");`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -269,10 +273,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-unless", Policy: `permit(principal,action,resource) unless { false };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -280,10 +284,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-if", Policy: `permit(principal,action,resource) when { (if true then true else true) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -291,10 +295,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-or", Policy: `permit(principal,action,resource) when { (true || false) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -302,10 +306,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-and", Policy: `permit(principal,action,resource) when { (true && true) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -313,10 +317,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-relations", Policy: `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -324,10 +328,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-relations-in", Policy: `permit(principal,action,resource) when { principal in principal };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -336,14 +340,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { principal has name };`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"coder", "cuzco"}, - Attributes: Record{"name": String("bob")}, + UID: types.EntityUID{"coder", "cuzco"}, + Attributes: types.Record{"name": types.String("bob")}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -351,10 +355,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-add-sub", Policy: `permit(principal,action,resource) when { 40+3-1==42 };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -362,10 +366,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-mul", Policy: `permit(principal,action,resource) when { 6*7==42 };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -373,10 +377,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-negate", Policy: `permit(principal,action,resource) when { -42==-42 };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -384,10 +388,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-not", Policy: `permit(principal,action,resource) when { !(1+1==42) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -395,10 +399,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -406,10 +410,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-record", Policy: `permit(principal,action,resource) when { {name:"bob"} has name };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -417,10 +421,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-action", Policy: `permit(principal,action,resource) when { action in action };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -428,10 +432,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-contains-ok", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -439,10 +443,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-contains-error", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -450,10 +454,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-containsAll-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -461,10 +465,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-containsAll-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -472,10 +476,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-containsAny-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -483,10 +487,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-containsAny-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -494,10 +498,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-record-attr", Policy: `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -505,10 +509,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-unknown-method", Policy: `permit(principal,action,resource) when { [1,2,3].shuffle() };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -516,10 +520,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-like", Policy: `permit(principal,action,resource) when { "bananas" like "*nan*" };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -527,10 +531,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-unknown-ext-fun", Policy: `permit(principal,action,resource) when { fooBar("10") };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -542,10 +546,10 @@ func TestIsAuthorized(t *testing.T) { decimal("10.0").greaterThan(decimal("9.0")) && decimal("10.0").greaterThanOrEqual(decimal("9.0")) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -553,10 +557,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-decimal-fun-wrong-arity", Policy: `permit(principal,action,resource) when { decimal(1, 2) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -569,10 +573,10 @@ func TestIsAuthorized(t *testing.T) { ip("224.1.2.3").isMulticast() && ip("127.0.0.1").isInRange(ip("127.0.0.0/16"))};`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -580,10 +584,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-ip-fun-wrong-arity", Policy: `permit(principal,action,resource) when { ip() };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -591,10 +595,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isIpv4-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -602,10 +606,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isIpv6-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -613,10 +617,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isLoopback-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -624,10 +628,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isMulticast-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -635,10 +639,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isInRange-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -646,7 +650,7 @@ func TestIsAuthorized(t *testing.T) { Name: "negative-unary-op", Policy: `permit(principal,action,resource) when { -context.value > 0 };`, Entities: entitiesFromSlice(nil), - Context: Record{"value": Long(-42)}, + Context: types.Record{"value": types.Long(-42)}, Want: true, DiagErr: 0, }, @@ -654,10 +658,10 @@ func TestIsAuthorized(t *testing.T) { Name: "principal-is", Policy: `permit(principal is Actor,action,resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -665,10 +669,10 @@ func TestIsAuthorized(t *testing.T) { Name: "principal-is-in", Policy: `permit(principal is Actor in Actor::"cuzco",action,resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -676,10 +680,10 @@ func TestIsAuthorized(t *testing.T) { Name: "resource-is", Policy: `permit(principal,action,resource is Resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -687,10 +691,10 @@ func TestIsAuthorized(t *testing.T) { Name: "resource-is-in", Policy: `permit(principal,action,resource is Resource in Resource::"table");`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -698,10 +702,10 @@ func TestIsAuthorized(t *testing.T) { Name: "when-is", Policy: `permit(principal,action,resource) when { resource is Resource };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -709,10 +713,10 @@ func TestIsAuthorized(t *testing.T) { Name: "when-is-in", Policy: `permit(principal,action,resource) when { resource is Resource in Resource::"table" };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -721,14 +725,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { resource is Resource in Parent::"id" };`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"Resource", "table"}, - Parents: []EntityUID{{"Parent", "id"}}, + UID: types.EntityUID{"Resource", "table"}, + Parents: []types.EntityUID{{"Parent", "id"}}, }, }), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -738,15 +742,15 @@ func TestIsAuthorized(t *testing.T) { t.Run(tt.Name, func(t *testing.T) { t.Parallel() ps, err := NewPolicySet("policy.cedar", []byte(tt.Policy)) - testutilOK(t, err) + testutil.OK(t, err) ok, diag := ps.IsAuthorized(tt.Entities, Request{ Principal: tt.Principal, Action: tt.Action, Resource: tt.Resource, Context: tt.Context, }) - testutilEquals(t, ok, tt.Want) - testutilEquals(t, len(diag.Errors), tt.DiagErr) + testutil.Equals(t, ok, tt.Want) + testutil.Equals(t, len(diag.Errors), tt.DiagErr) }) } } @@ -757,41 +761,41 @@ func TestEntities(t *testing.T) { t.Parallel() s := []Entity{ { - UID: EntityUID{Type: "A", ID: "A"}, + UID: types.EntityUID{Type: "A", ID: "A"}, }, { - UID: EntityUID{Type: "A", ID: "B"}, + UID: types.EntityUID{Type: "A", ID: "B"}, }, { - UID: EntityUID{Type: "B", ID: "A"}, + UID: types.EntityUID{Type: "B", ID: "A"}, }, { - UID: EntityUID{Type: "B", ID: "B"}, + UID: types.EntityUID{Type: "B", ID: "B"}, }, } entities := entitiesFromSlice(s) s2 := entities.toSlice() - testutilEquals(t, s2, s) + testutil.Equals(t, s2, s) }) t.Run("Clone", func(t *testing.T) { t.Parallel() s := []Entity{ { - UID: EntityUID{Type: "A", ID: "A"}, + UID: types.EntityUID{Type: "A", ID: "A"}, }, { - UID: EntityUID{Type: "A", ID: "B"}, + UID: types.EntityUID{Type: "A", ID: "B"}, }, { - UID: EntityUID{Type: "B", ID: "A"}, + UID: types.EntityUID{Type: "B", ID: "A"}, }, { - UID: EntityUID{Type: "B", ID: "B"}, + UID: types.EntityUID{Type: "B", ID: "B"}, }, } entities := entitiesFromSlice(s) clone := entities.Clone() - testutilEquals(t, clone, entities) + testutil.Equals(t, clone, entities) }) } @@ -800,37 +804,37 @@ func TestValueFrom(t *testing.T) { t.Parallel() tests := []struct { name string - in Value + in types.Value outJSON string }{ { name: "string", - in: String("hello"), + in: types.String("hello"), outJSON: `"hello"`, }, { name: "bool", - in: Boolean(true), + in: types.Boolean(true), outJSON: `true`, }, { name: "int64", - in: Long(42), + in: types.Long(42), outJSON: `42`, }, { name: "int64", - in: EntityUID{Type: "T", ID: "0"}, + in: types.EntityUID{Type: "T", ID: "0"}, outJSON: `{"__entity":{"type":"T","id":"0"}}`, }, { name: "record", - in: Record{"K": Boolean(true)}, + in: types.Record{"K": types.Boolean(true)}, outJSON: `{"K":true}`, }, { name: "netipPrefix", - in: IPAddr(netip.MustParsePrefix("192.168.0.42/32")), + in: types.IPAddr(netip.MustParsePrefix("192.168.0.42/32")), outJSON: `{"__extn":{"fn":"ip","arg":"192.168.0.42"}}`, }, } @@ -840,8 +844,8 @@ func TestValueFrom(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() out, err := tt.in.ExplicitMarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(out), tt.outJSON) + testutil.OK(t, err) + testutil.Equals(t, string(out), tt.outJSON) }) } } @@ -849,7 +853,7 @@ func TestValueFrom(t *testing.T) { func TestError(t *testing.T) { t.Parallel() e := Error{Policy: 42, Message: "bad error"} - testutilEquals(t, e.String(), "while evaluating policy `policy42`: bad error") + testutil.Equals(t, e.String(), "while evaluating policy `policy42`: bad error") } func TestInvalidPolicy(t *testing.T) { @@ -858,12 +862,12 @@ func TestInvalidPolicy(t *testing.T) { ps := PolicySet{ { Effect: Forbid, - eval: newLiteralEval(Long(42)), + eval: newLiteralEval(types.Long(42)), }, } ok, diag := ps.IsAuthorized(Entities{}, Request{}) - testutilEquals(t, ok, Deny) - testutilEquals(t, diag, Diagnostic{ + testutil.Equals(t, ok, Deny) + testutil.Equals(t, diag, Diagnostic{ Errors: []Error{ { Policy: 0, @@ -892,7 +896,7 @@ func TestCorpusRelated(t *testing.T) { ) when { (true && (((!870985681610) == principal) == principal)) && principal };`, - Request{Principal: NewEntityUID("a", "\u0000\u0000"), Action: NewEntityUID("Action", "action"), Resource: NewEntityUID("a", "\u0000\u0000")}, + Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, Deny, nil, []int{0}, @@ -907,7 +911,7 @@ func TestCorpusRelated(t *testing.T) { ) when { (((!870985681610) == principal) == principal) };`, - Request{Principal: NewEntityUID("a", "\u0000\u0000"), Action: NewEntityUID("Action", "action"), Resource: NewEntityUID("a", "\u0000\u0000")}, + Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, Deny, nil, []int{0}, @@ -921,7 +925,7 @@ func TestCorpusRelated(t *testing.T) { ) when { ((!870985681610) == principal) };`, - Request{Principal: NewEntityUID("a", "\u0000\u0000"), Action: NewEntityUID("Action", "action"), Resource: NewEntityUID("a", "\u0000\u0000")}, + Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, Deny, nil, []int{0}, @@ -936,7 +940,7 @@ func TestCorpusRelated(t *testing.T) { ) when { (!870985681610) };`, - Request{Principal: NewEntityUID("a", "\u0000\u0000"), Action: NewEntityUID("Action", "action"), Resource: NewEntityUID("a", "\u0000\u0000")}, + Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, Deny, nil, []int{0}, @@ -980,7 +984,7 @@ func TestCorpusRelated(t *testing.T) { ) when { true && ((if (principal in action) then (ip("")) else (if true then (ip("6b6b:f00::32ff:ffff:6368/00")) else (ip("7265:6c69:706d:6f43:5f74:6f70:7374:6f68")))).isMulticast()) };`, - Request{Principal: NewEntityUID("a", "\u0000\b\u0011\u0000R"), Action: NewEntityUID("Action", "action"), Resource: NewEntityUID("a", "\u0000\b\u0011\u0000R")}, + Request{Principal: types.NewEntityUID("a", "\u0000\b\u0011\u0000R"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\b\u0011\u0000R")}, Deny, nil, []int{0}, @@ -1008,7 +1012,7 @@ func TestCorpusRelated(t *testing.T) { ) when { true && (([ip("c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5/68")].containsAll([ip("c5c5:c5c5:c5c5:c5c5:c5c5:5cc5:c5c5:c5c5/68")])) || ((ip("")) == (ip("")))) };`, - request: Request{Principal: NewEntityUID("a", "\u0000\u0000(W\u0000\u0000\u0000"), Action: NewEntityUID("Action", "action"), Resource: NewEntityUID("a", "")}, + request: Request{Principal: types.NewEntityUID("a", "\u0000\u0000(W\u0000\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "")}, decision: Deny, reasons: nil, errors: []int{0}, @@ -1019,19 +1023,123 @@ func TestCorpusRelated(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() policy, err := NewPolicySet("", []byte(tt.policy)) - testutilOK(t, err) + testutil.OK(t, err) ok, diag := policy.IsAuthorized(Entities{}, tt.request) - testutilEquals(t, ok, tt.decision) + testutil.Equals(t, ok, tt.decision) var reasons []int for _, n := range diag.Reasons { reasons = append(reasons, n.Policy) } - testutilEquals(t, reasons, tt.reasons) + testutil.Equals(t, reasons, tt.reasons) var errors []int for _, n := range diag.Errors { errors = append(errors, n.Policy) } - testutilEquals(t, errors, tt.errors) + testutil.Equals(t, errors, tt.errors) }) } } + +func TestEntitiesJSON(t *testing.T) { + t.Parallel() + t.Run("Marshal", func(t *testing.T) { + t.Parallel() + e := Entities{} + ent := Entity{ + UID: types.NewEntityUID("Type", "id"), + Parents: []types.EntityUID{}, + Attributes: types.Record{"key": types.Long(42)}, + } + e[ent.UID] = ent + b, err := e.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `[{"uid":{"type":"Type","id":"id"},"attrs":{"key":42}}]`) + }) + + t.Run("Unmarshal", func(t *testing.T) { + t.Parallel() + b := []byte(`[{"uid":{"type":"Type","id":"id"},"parents":[],"attrs":{"key":42}}]`) + var e Entities + err := json.Unmarshal(b, &e) + testutil.OK(t, err) + want := Entities{} + ent := Entity{ + UID: types.NewEntityUID("Type", "id"), + Parents: []types.EntityUID{}, + Attributes: types.Record{"key": types.Long(42)}, + } + want[ent.UID] = ent + testutil.Equals(t, e, want) + }) + + t.Run("UnmarshalErr", func(t *testing.T) { + t.Parallel() + var e Entities + err := e.UnmarshalJSON([]byte(`!@#$`)) + testutil.Error(t, err) + }) +} + +func TestJSONEffect(t *testing.T) { + t.Parallel() + t.Run("MarshalPermit", func(t *testing.T) { + t.Parallel() + e := Permit + b, err := e.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `"permit"`) + }) + t.Run("MarshalForbid", func(t *testing.T) { + t.Parallel() + e := Forbid + b, err := e.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `"forbid"`) + }) + t.Run("UnmarshalPermit", func(t *testing.T) { + t.Parallel() + var e Effect + err := json.Unmarshal([]byte(`"permit"`), &e) + testutil.OK(t, err) + testutil.Equals(t, e, Permit) + }) + t.Run("UnmarshalForbid", func(t *testing.T) { + t.Parallel() + var e Effect + err := json.Unmarshal([]byte(`"forbid"`), &e) + testutil.OK(t, err) + testutil.Equals(t, e, Forbid) + }) +} + +func TestJSONDecision(t *testing.T) { + t.Parallel() + t.Run("MarshalAllow", func(t *testing.T) { + t.Parallel() + d := Allow + b, err := d.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `"allow"`) + }) + t.Run("MarshalDeny", func(t *testing.T) { + t.Parallel() + d := Deny + b, err := d.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `"deny"`) + }) + t.Run("UnmarshalAllow", func(t *testing.T) { + t.Parallel() + var d Decision + err := json.Unmarshal([]byte(`"allow"`), &d) + testutil.OK(t, err) + testutil.Equals(t, d, Allow) + }) + t.Run("UnmarshalDeny", func(t *testing.T) { + t.Parallel() + var d Decision + err := json.Unmarshal([]byte(`"deny"`), &d) + testutil.OK(t, err) + testutil.Equals(t, d, Deny) + }) +} diff --git a/corpus_test.go b/corpus_test.go index e27353d9..305c54f9 100644 --- a/corpus_test.go +++ b/corpus_test.go @@ -11,18 +11,20 @@ import ( "slices" "strings" "testing" + + "github.com/cedar-policy/cedar-go/types" ) // jsonEntity is not part of entityValue as I can find // no evidence this is part of the JSON spec. It also // requires creating a parser, so it's quite expensive. -type jsonEntity EntityUID +type jsonEntity types.EntityUID func (e *jsonEntity) UnmarshalJSON(b []byte) error { if string(b) == "null" { return nil } - var input EntityUID + var input types.EntityUID if err := json.Unmarshal(b, &input); err != nil { return err } @@ -36,14 +38,14 @@ type corpusTest struct { ShouldValidate bool `json:"shouldValidate"` Entities string `json:"entities"` Requests []struct { - Desc string `json:"description"` - Principal jsonEntity `json:"principal"` - Action jsonEntity `json:"action"` - Resource jsonEntity `json:"resource"` - Context Record `json:"context"` - Decision string `json:"decision"` - Reasons []string `json:"reason"` - Errors []string `json:"errors"` + Desc string `json:"description"` + Principal jsonEntity `json:"principal"` + Action jsonEntity `json:"action"` + Resource jsonEntity `json:"resource"` + Context types.Record `json:"context"` + Decision string `json:"decision"` + Reasons []string `json:"reason"` + Errors []string `json:"errors"` } `json:"requests"` } @@ -157,9 +159,9 @@ func TestCorpus(t *testing.T) { ok, diag := policySet.IsAuthorized( entities, Request{ - Principal: EntityUID(request.Principal), - Action: EntityUID(request.Action), - Resource: EntityUID(request.Resource), + Principal: types.EntityUID(request.Principal), + Action: types.EntityUID(request.Action), + Resource: types.EntityUID(request.Resource), Context: request.Context, }) diff --git a/eval.go b/eval.go index 30a739a7..b38159c5 100644 --- a/eval.go +++ b/eval.go @@ -3,194 +3,120 @@ package cedar import ( "fmt" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/parser" ) var errOverflow = fmt.Errorf("integer overflow") -var errType = fmt.Errorf("type error") var errUnknownMethod = fmt.Errorf("unknown method") var errUnknownExtensionFunction = fmt.Errorf("function does not exist") var errArity = fmt.Errorf("wrong number of arguments provided to extension function") var errAttributeAccess = fmt.Errorf("does not have the attribute") -var errDecimal = fmt.Errorf("error parsing decimal value") -var errIP = fmt.Errorf("error parsing ip value") var errEntityNotExist = fmt.Errorf("does not exist") var errUnspecifiedEntity = fmt.Errorf("unspecified entity") type evalContext struct { Entities Entities - Principal, Action, Resource Value - Context Value + Principal, Action, Resource types.Value + Context types.Value } type evaler interface { - Eval(*evalContext) (Value, error) + Eval(*evalContext) (types.Value, error) } -func valueToBool(v Value) (Boolean, error) { - bv, ok := v.(Boolean) - if !ok { - return false, fmt.Errorf("%w: expected bool, got %v", errType, v.typeName()) - } - return bv, nil -} - -func evalBool(n evaler, ctx *evalContext) (Boolean, error) { +func evalBool(n evaler, ctx *evalContext) (types.Boolean, error) { v, err := n.Eval(ctx) if err != nil { return false, err } - b, err := valueToBool(v) + b, err := types.ValueToBool(v) if err != nil { return false, err } return b, nil } -func valueToLong(v Value) (Long, error) { - lv, ok := v.(Long) - if !ok { - return 0, fmt.Errorf("%w: expected long, got %v", errType, v.typeName()) - } - return lv, nil -} - -func evalLong(n evaler, ctx *evalContext) (Long, error) { +func evalLong(n evaler, ctx *evalContext) (types.Long, error) { v, err := n.Eval(ctx) if err != nil { return 0, err } - l, err := valueToLong(v) + l, err := types.ValueToLong(v) if err != nil { return 0, err } return l, nil } -func valueToString(v Value) (String, error) { - sv, ok := v.(String) - if !ok { - return "", fmt.Errorf("%w: expected string, got %v", errType, v.typeName()) - } - return sv, nil -} - -func evalString(n evaler, ctx *evalContext) (String, error) { +func evalString(n evaler, ctx *evalContext) (types.String, error) { v, err := n.Eval(ctx) if err != nil { return "", err } - s, err := valueToString(v) + s, err := types.ValueToString(v) if err != nil { return "", err } return s, nil } -func valueToSet(v Value) (Set, error) { - sv, ok := v.(Set) - if !ok { - return nil, fmt.Errorf("%w: expected set, got %v", errType, v.typeName()) - } - return sv, nil -} - -func evalSet(n evaler, ctx *evalContext) (Set, error) { +func evalSet(n evaler, ctx *evalContext) (types.Set, error) { v, err := n.Eval(ctx) if err != nil { return nil, err } - s, err := valueToSet(v) + s, err := types.ValueToSet(v) if err != nil { return nil, err } return s, nil } -func valueToRecord(v Value) (Record, error) { - rv, ok := v.(Record) - if !ok { - return nil, fmt.Errorf("%w: expected record got %v", errType, v.typeName()) - } - return rv, nil -} - -func valueToEntity(v Value) (EntityUID, error) { - ev, ok := v.(EntityUID) - if !ok { - return EntityUID{}, fmt.Errorf("%w: expected (entity of type `any_entity_type`), got %v", errType, v.typeName()) - } - return ev, nil -} - -func valueToPath(v Value) (path, error) { - ev, ok := v.(path) - if !ok { - return "", fmt.Errorf("%w: expected (path of type `any_entity_type`), got %v", errType, v.typeName()) - } - return ev, nil -} - -func evalEntity(n evaler, ctx *evalContext) (EntityUID, error) { +func evalEntity(n evaler, ctx *evalContext) (types.EntityUID, error) { v, err := n.Eval(ctx) if err != nil { - return EntityUID{}, err + return types.EntityUID{}, err } - e, err := valueToEntity(v) + e, err := types.ValueToEntity(v) if err != nil { - return EntityUID{}, err + return types.EntityUID{}, err } return e, nil } -func evalPath(n evaler, ctx *evalContext) (path, error) { +func evalPath(n evaler, ctx *evalContext) (types.Path, error) { v, err := n.Eval(ctx) if err != nil { return "", err } - e, err := valueToPath(v) + e, err := types.ValueToPath(v) if err != nil { return "", err } return e, nil } -func valueToDecimal(v Value) (Decimal, error) { - d, ok := v.(Decimal) - if !ok { - return 0, fmt.Errorf("%w: expected decimal, got %v", errType, v.typeName()) - } - return d, nil -} - -func evalDecimal(n evaler, ctx *evalContext) (Decimal, error) { +func evalDecimal(n evaler, ctx *evalContext) (types.Decimal, error) { v, err := n.Eval(ctx) if err != nil { - return Decimal(0), err + return types.Decimal(0), err } - d, err := valueToDecimal(v) + d, err := types.ValueToDecimal(v) if err != nil { - return Decimal(0), err + return types.Decimal(0), err } return d, nil } -func valueToIP(v Value) (IPAddr, error) { - i, ok := v.(IPAddr) - if !ok { - return IPAddr{}, fmt.Errorf("%w: expected ipaddr, got %v", errType, v.typeName()) - } - return i, nil -} - -func evalIP(n evaler, ctx *evalContext) (IPAddr, error) { +func evalIP(n evaler, ctx *evalContext) (types.IPAddr, error) { v, err := n.Eval(ctx) if err != nil { - return IPAddr{}, err + return types.IPAddr{}, err } - i, err := valueToIP(v) + i, err := types.ValueToIP(v) if err != nil { - return IPAddr{}, err + return types.IPAddr{}, err } return i, nil } @@ -206,20 +132,20 @@ func newErrorEval(err error) *errorEval { } } -func (n *errorEval) Eval(_ *evalContext) (Value, error) { - return zeroValue(), n.err +func (n *errorEval) Eval(_ *evalContext) (types.Value, error) { + return types.ZeroValue(), n.err } // literalEval type literalEval struct { - value Value + value types.Value } -func newLiteralEval(value Value) *literalEval { +func newLiteralEval(value types.Value) *literalEval { return &literalEval{value: value} } -func (n *literalEval) Eval(_ *evalContext) (Value, error) { +func (n *literalEval) Eval(_ *evalContext) (types.Value, error) { return n.value, nil } @@ -236,25 +162,25 @@ func newOrNode(lhs evaler, rhs evaler) *orEval { } } -func (n *orEval) Eval(ctx *evalContext) (Value, error) { +func (n *orEval) Eval(ctx *evalContext) (types.Value, error) { v, err := n.lhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - b, err := valueToBool(v) + b, err := types.ValueToBool(v) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } if b { return v, nil } v, err = n.rhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - _, err = valueToBool(v) + _, err = types.ValueToBool(v) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } return v, nil } @@ -272,25 +198,25 @@ func newAndEval(lhs evaler, rhs evaler) *andEval { } } -func (n *andEval) Eval(ctx *evalContext) (Value, error) { +func (n *andEval) Eval(ctx *evalContext) (types.Value, error) { v, err := n.lhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - b, err := valueToBool(v) + b, err := types.ValueToBool(v) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } if !b { return v, nil } v, err = n.rhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - _, err = valueToBool(v) + _, err = types.ValueToBool(v) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } return v, nil } @@ -306,14 +232,14 @@ func newNotEval(inner evaler) *notEval { } } -func (n *notEval) Eval(ctx *evalContext) (Value, error) { +func (n *notEval) Eval(ctx *evalContext) (types.Value, error) { v, err := n.inner.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - b, err := valueToBool(v) + b, err := types.ValueToBool(v) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } return !b, nil } @@ -323,7 +249,7 @@ func (n *notEval) Eval(ctx *evalContext) (Value, error) { // behavior (https://go.dev/ref/spec#Integer_overflow), so we can go ahead and // do the operations and then check for overflow ex post facto. -func checkedAddI64(lhs, rhs Long) (Long, bool) { +func checkedAddI64(lhs, rhs types.Long) (types.Long, bool) { result := lhs + rhs if (result > lhs) != (rhs > 0) { return result, false @@ -331,7 +257,7 @@ func checkedAddI64(lhs, rhs Long) (Long, bool) { return result, true } -func checkedSubI64(lhs, rhs Long) (Long, bool) { +func checkedSubI64(lhs, rhs types.Long) (types.Long, bool) { result := lhs - rhs if (result > lhs) != (rhs < 0) { return result, false @@ -339,7 +265,7 @@ func checkedSubI64(lhs, rhs Long) (Long, bool) { return result, true } -func checkedMulI64(lhs, rhs Long) (Long, bool) { +func checkedMulI64(lhs, rhs types.Long) (types.Long, bool) { if lhs == 0 || rhs == 0 { return 0, true } @@ -355,7 +281,7 @@ func checkedMulI64(lhs, rhs Long) (Long, bool) { return result, true } -func checkedNegI64(a Long) (Long, bool) { +func checkedNegI64(a types.Long) (types.Long, bool) { if a == -9_223_372_036_854_775_808 { return 0, false } @@ -375,18 +301,18 @@ func newAddEval(lhs evaler, rhs evaler) *addEval { } } -func (n *addEval) Eval(ctx *evalContext) (Value, error) { +func (n *addEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } res, ok := checkedAddI64(lhs, rhs) if !ok { - return zeroValue(), fmt.Errorf("%w while attempting to add `%d` with `%d`", errOverflow, lhs, rhs) + return types.ZeroValue(), fmt.Errorf("%w while attempting to add `%d` with `%d`", errOverflow, lhs, rhs) } return res, nil } @@ -404,18 +330,18 @@ func newSubtractEval(lhs evaler, rhs evaler) *subtractEval { } } -func (n *subtractEval) Eval(ctx *evalContext) (Value, error) { +func (n *subtractEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } res, ok := checkedSubI64(lhs, rhs) if !ok { - return zeroValue(), fmt.Errorf("%w while attempting to subtract `%d` from `%d`", errOverflow, rhs, lhs) + return types.ZeroValue(), fmt.Errorf("%w while attempting to subtract `%d` from `%d`", errOverflow, rhs, lhs) } return res, nil } @@ -433,18 +359,18 @@ func newMultiplyEval(lhs evaler, rhs evaler) *multiplyEval { } } -func (n *multiplyEval) Eval(ctx *evalContext) (Value, error) { +func (n *multiplyEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } res, ok := checkedMulI64(lhs, rhs) if !ok { - return zeroValue(), fmt.Errorf("%w while attempting to multiply `%d` by `%d`", errOverflow, lhs, rhs) + return types.ZeroValue(), fmt.Errorf("%w while attempting to multiply `%d` by `%d`", errOverflow, lhs, rhs) } return res, nil } @@ -460,14 +386,14 @@ func newNegateEval(inner evaler) *negateEval { } } -func (n *negateEval) Eval(ctx *evalContext) (Value, error) { +func (n *negateEval) Eval(ctx *evalContext) (types.Value, error) { inner, err := evalLong(n.inner, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } res, ok := checkedNegI64(inner) if !ok { - return zeroValue(), fmt.Errorf("%w while attempting to negate `%d`", errOverflow, inner) + return types.ZeroValue(), fmt.Errorf("%w while attempting to negate `%d`", errOverflow, inner) } return res, nil } @@ -485,16 +411,16 @@ func newLongLessThanEval(lhs evaler, rhs evaler) *longLessThanEval { } } -func (n *longLessThanEval) Eval(ctx *evalContext) (Value, error) { +func (n *longLessThanEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs < rhs), nil + return types.Boolean(lhs < rhs), nil } // longLessThanOrEqualEval @@ -510,16 +436,16 @@ func newLongLessThanOrEqualEval(lhs evaler, rhs evaler) *longLessThanOrEqualEval } } -func (n *longLessThanOrEqualEval) Eval(ctx *evalContext) (Value, error) { +func (n *longLessThanOrEqualEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs <= rhs), nil + return types.Boolean(lhs <= rhs), nil } // longGreaterThanEval @@ -535,16 +461,16 @@ func newLongGreaterThanEval(lhs evaler, rhs evaler) *longGreaterThanEval { } } -func (n *longGreaterThanEval) Eval(ctx *evalContext) (Value, error) { +func (n *longGreaterThanEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs > rhs), nil + return types.Boolean(lhs > rhs), nil } // longGreaterThanOrEqualEval @@ -560,16 +486,16 @@ func newLongGreaterThanOrEqualEval(lhs evaler, rhs evaler) *longGreaterThanOrEqu } } -func (n *longGreaterThanOrEqualEval) Eval(ctx *evalContext) (Value, error) { +func (n *longGreaterThanOrEqualEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs >= rhs), nil + return types.Boolean(lhs >= rhs), nil } // decimalLessThanEval @@ -585,16 +511,16 @@ func newDecimalLessThanEval(lhs evaler, rhs evaler) *decimalLessThanEval { } } -func (n *decimalLessThanEval) Eval(ctx *evalContext) (Value, error) { +func (n *decimalLessThanEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalDecimal(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs < rhs), nil + return types.Boolean(lhs < rhs), nil } // decimalLessThanOrEqualEval @@ -610,16 +536,16 @@ func newDecimalLessThanOrEqualEval(lhs evaler, rhs evaler) *decimalLessThanOrEqu } } -func (n *decimalLessThanOrEqualEval) Eval(ctx *evalContext) (Value, error) { +func (n *decimalLessThanOrEqualEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalDecimal(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs <= rhs), nil + return types.Boolean(lhs <= rhs), nil } // decimalGreaterThanEval @@ -635,16 +561,16 @@ func newDecimalGreaterThanEval(lhs evaler, rhs evaler) *decimalGreaterThanEval { } } -func (n *decimalGreaterThanEval) Eval(ctx *evalContext) (Value, error) { +func (n *decimalGreaterThanEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalDecimal(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs > rhs), nil + return types.Boolean(lhs > rhs), nil } // decimalGreaterThanOrEqualEval @@ -660,16 +586,16 @@ func newDecimalGreaterThanOrEqualEval(lhs evaler, rhs evaler) *decimalGreaterTha } } -func (n *decimalGreaterThanOrEqualEval) Eval(ctx *evalContext) (Value, error) { +func (n *decimalGreaterThanOrEqualEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalDecimal(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs >= rhs), nil + return types.Boolean(lhs >= rhs), nil } // ifThenElseEval @@ -687,10 +613,10 @@ func newIfThenElseEval(if_, then, else_ evaler) *ifThenElseEval { } } -func (n *ifThenElseEval) Eval(ctx *evalContext) (Value, error) { +func (n *ifThenElseEval) Eval(ctx *evalContext) (types.Value, error) { cond, err := evalBool(n.if_, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } if cond { return n.then.Eval(ctx) @@ -710,16 +636,16 @@ func newEqualEval(lhs, rhs evaler) *equalEval { } } -func (n *equalEval) Eval(ctx *evalContext) (Value, error) { +func (n *equalEval) Eval(ctx *evalContext) (types.Value, error) { lv, err := n.lhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rv, err := n.rhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lv.equal(rv)), nil + return types.Boolean(lv.Equal(rv)), nil } // notEqualEval @@ -734,16 +660,16 @@ func newNotEqualEval(lhs, rhs evaler) *notEqualEval { } } -func (n *notEqualEval) Eval(ctx *evalContext) (Value, error) { +func (n *notEqualEval) Eval(ctx *evalContext) (types.Value, error) { lv, err := n.lhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rv, err := n.rhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(!lv.equal(rv)), nil + return types.Boolean(!lv.Equal(rv)), nil } // setLiteralEval @@ -755,12 +681,12 @@ func newSetLiteralEval(elements []evaler) *setLiteralEval { return &setLiteralEval{elements: elements} } -func (n *setLiteralEval) Eval(ctx *evalContext) (Value, error) { - var vals Set +func (n *setLiteralEval) Eval(ctx *evalContext) (types.Value, error) { + var vals types.Set for _, e := range n.elements { v, err := e.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } vals = append(vals, v) } @@ -779,16 +705,16 @@ func newContainsEval(lhs, rhs evaler) *containsEval { } } -func (n *containsEval) Eval(ctx *evalContext) (Value, error) { +func (n *containsEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalSet(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := n.rhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs.contains(rhs)), nil + return types.Boolean(lhs.Contains(rhs)), nil } // containsAllEval @@ -803,23 +729,23 @@ func newContainsAllEval(lhs, rhs evaler) *containsAllEval { } } -func (n *containsAllEval) Eval(ctx *evalContext) (Value, error) { +func (n *containsAllEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalSet(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalSet(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } result := true for _, e := range rhs { - if !lhs.contains(e) { + if !lhs.Contains(e) { result = false break } } - return Boolean(result), nil + return types.Boolean(result), nil } // containsAnyEval @@ -834,23 +760,23 @@ func newContainsAnyEval(lhs, rhs evaler) *containsAnyEval { } } -func (n *containsAnyEval) Eval(ctx *evalContext) (Value, error) { +func (n *containsAnyEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalSet(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalSet(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } result := false for _, e := range rhs { - if lhs.contains(e) { + if lhs.Contains(e) { result = true break } } - return Boolean(result), nil + return types.Boolean(result), nil } // recordLiteralEval @@ -862,12 +788,12 @@ func newRecordLiteralEval(elements map[string]evaler) *recordLiteralEval { return &recordLiteralEval{elements: elements} } -func (n *recordLiteralEval) Eval(ctx *evalContext) (Value, error) { - vals := Record{} +func (n *recordLiteralEval) Eval(ctx *evalContext) (types.Value, error) { + vals := types.Record{} for k, en := range n.elements { v, err := en.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } vals[k] = v } @@ -884,34 +810,34 @@ func newAttributeAccessEval(record evaler, attribute string) *attributeAccessEva return &attributeAccessEval{object: record, attribute: attribute} } -func (n *attributeAccessEval) Eval(ctx *evalContext) (Value, error) { +func (n *attributeAccessEval) Eval(ctx *evalContext) (types.Value, error) { v, err := n.object.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - var record Record + var record types.Record key := "record" switch vv := v.(type) { - case EntityUID: + case types.EntityUID: key = "`" + vv.String() + "`" - var unspecified EntityUID + var unspecified types.EntityUID if vv == unspecified { - return zeroValue(), fmt.Errorf("cannot access attribute `%s` of %w", n.attribute, errUnspecifiedEntity) + return types.ZeroValue(), fmt.Errorf("cannot access attribute `%s` of %w", n.attribute, errUnspecifiedEntity) } rec, ok := ctx.Entities[vv] if !ok { - return zeroValue(), fmt.Errorf("entity `%v` %w", vv.String(), errEntityNotExist) + return types.ZeroValue(), fmt.Errorf("entity `%v` %w", vv.String(), errEntityNotExist) } else { record = rec.Attributes } - case Record: + case types.Record: record = vv default: - return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", errType, v.typeName()) + return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", types.ErrType, v.TypeName()) } val, ok := record[n.attribute] if !ok { - return zeroValue(), fmt.Errorf("%s %w `%s`", key, errAttributeAccess, n.attribute) + return types.ZeroValue(), fmt.Errorf("%s %w `%s`", key, errAttributeAccess, n.attribute) } return val, nil } @@ -926,27 +852,27 @@ func newHasEval(record evaler, attribute string) *hasEval { return &hasEval{object: record, attribute: attribute} } -func (n *hasEval) Eval(ctx *evalContext) (Value, error) { +func (n *hasEval) Eval(ctx *evalContext) (types.Value, error) { v, err := n.object.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - var record Record + var record types.Record switch vv := v.(type) { - case EntityUID: + case types.EntityUID: rec, ok := ctx.Entities[vv] if !ok { - record = Record{} + record = types.Record{} } else { record = rec.Attributes } - case Record: + case types.Record: record = vv default: - return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", errType, v.typeName()) + return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", types.ErrType, v.TypeName()) } _, ok := record[n.attribute] - return Boolean(ok), nil + return types.Boolean(ok), nil } // likeEval @@ -959,20 +885,20 @@ func newLikeEval(lhs evaler, pattern parser.Pattern) *likeEval { return &likeEval{lhs: lhs, pattern: pattern} } -func (l *likeEval) Eval(ctx *evalContext) (Value, error) { +func (l *likeEval) Eval(ctx *evalContext) (types.Value, error) { v, err := evalString(l.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(match(l.pattern, string(v))), nil + return types.Boolean(match(l.pattern, string(v))), nil } -type variableName func(ctx *evalContext) Value +type variableName func(ctx *evalContext) types.Value -func variableNamePrincipal(ctx *evalContext) Value { return ctx.Principal } -func variableNameAction(ctx *evalContext) Value { return ctx.Action } -func variableNameResource(ctx *evalContext) Value { return ctx.Resource } -func variableNameContext(ctx *evalContext) Value { return ctx.Context } +func variableNamePrincipal(ctx *evalContext) types.Value { return ctx.Principal } +func variableNameAction(ctx *evalContext) types.Value { return ctx.Action } +func variableNameResource(ctx *evalContext) types.Value { return ctx.Resource } +func variableNameContext(ctx *evalContext) types.Value { return ctx.Context } // variableEval type variableEval struct { @@ -983,7 +909,7 @@ func newVariableEval(variableName variableName) *variableEval { return &variableEval{variableName: variableName} } -func (n *variableEval) Eval(ctx *evalContext) (Value, error) { +func (n *variableEval) Eval(ctx *evalContext) (types.Value, error) { return n.variableName(ctx), nil } @@ -996,11 +922,11 @@ func newInEval(lhs, rhs evaler) *inEval { return &inEval{lhs: lhs, rhs: rhs} } -func entityIn(entity EntityUID, query map[EntityUID]struct{}, entities Entities) bool { - checked := map[EntityUID]struct{}{} - toCheck := []EntityUID{entity} +func entityIn(entity types.EntityUID, query map[types.EntityUID]struct{}, entities Entities) bool { + checked := map[types.EntityUID]struct{}{} + toCheck := []types.EntityUID{entity} for len(toCheck) > 0 { - var candidate EntityUID + var candidate types.EntityUID candidate, toCheck = toCheck[len(toCheck)-1], toCheck[:len(toCheck)-1] if _, ok := checked[candidate]; ok { continue @@ -1014,34 +940,34 @@ func entityIn(entity EntityUID, query map[EntityUID]struct{}, entities Entities) return false } -func (n *inEval) Eval(ctx *evalContext) (Value, error) { +func (n *inEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalEntity(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := n.rhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - query := map[EntityUID]struct{}{} + query := map[types.EntityUID]struct{}{} switch rhsv := rhs.(type) { - case EntityUID: + case types.EntityUID: query[rhsv] = struct{}{} - case Set: + case types.Set: for _, rhv := range rhsv { - e, err := valueToEntity(rhv) + e, err := types.ValueToEntity(rhv) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } query[e] = struct{}{} } default: - return zeroValue(), fmt.Errorf( - "%w: expected one of [set, (entity of type `any_entity_type`)], got %v", errType, rhs.typeName()) + return types.ZeroValue(), fmt.Errorf( + "%w: expected one of [set, (entity of type `any_entity_type`)], got %v", types.ErrType, rhs.TypeName()) } - return Boolean(entityIn(lhs, query, ctx.Entities)), nil + return types.Boolean(entityIn(lhs, query, ctx.Entities)), nil } // isEval @@ -1053,18 +979,18 @@ func newIsEval(lhs, rhs evaler) *isEval { return &isEval{lhs: lhs, rhs: rhs} } -func (n *isEval) Eval(ctx *evalContext) (Value, error) { +func (n *isEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalEntity(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalPath(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(path(lhs.Type) == rhs), nil + return types.Boolean(types.Path(lhs.Type) == rhs), nil } // decimalLiteralEval @@ -1076,15 +1002,15 @@ func newDecimalLiteralEval(literal evaler) *decimalLiteralEval { return &decimalLiteralEval{literal: literal} } -func (n *decimalLiteralEval) Eval(ctx *evalContext) (Value, error) { +func (n *decimalLiteralEval) Eval(ctx *evalContext) (types.Value, error) { literal, err := evalString(n.literal, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - d, err := ParseDecimal(string(literal)) + d, err := types.ParseDecimal(string(literal)) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } return d, nil @@ -1098,26 +1024,26 @@ func newIPLiteralEval(literal evaler) *ipLiteralEval { return &ipLiteralEval{literal: literal} } -func (n *ipLiteralEval) Eval(ctx *evalContext) (Value, error) { +func (n *ipLiteralEval) Eval(ctx *evalContext) (types.Value, error) { literal, err := evalString(n.literal, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - i, err := ParseIPAddr(string(literal)) + i, err := types.ParseIPAddr(string(literal)) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } return i, nil } -type ipTestType func(v IPAddr) bool +type ipTestType func(v types.IPAddr) bool -func ipTestIPv4(v IPAddr) bool { return v.isIPv4() } -func ipTestIPv6(v IPAddr) bool { return v.isIPv6() } -func ipTestLoopback(v IPAddr) bool { return v.isLoopback() } -func ipTestMulticast(v IPAddr) bool { return v.isMulticast() } +func ipTestIPv4(v types.IPAddr) bool { return v.IsIPv4() } +func ipTestIPv6(v types.IPAddr) bool { return v.IsIPv6() } +func ipTestLoopback(v types.IPAddr) bool { return v.IsLoopback() } +func ipTestMulticast(v types.IPAddr) bool { return v.IsMulticast() } // ipTestEval type ipTestEval struct { @@ -1129,12 +1055,12 @@ func newIPTestEval(object evaler, test ipTestType) *ipTestEval { return &ipTestEval{object: object, test: test} } -func (n *ipTestEval) Eval(ctx *evalContext) (Value, error) { +func (n *ipTestEval) Eval(ctx *evalContext) (types.Value, error) { i, err := evalIP(n.object, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(n.test(i)), nil + return types.Boolean(n.test(i)), nil } // ipIsInRangeEval @@ -1147,14 +1073,14 @@ func newIPIsInRangeEval(lhs, rhs evaler) *ipIsInRangeEval { return &ipIsInRangeEval{lhs: lhs, rhs: rhs} } -func (n *ipIsInRangeEval) Eval(ctx *evalContext) (Value, error) { +func (n *ipIsInRangeEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalIP(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalIP(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(rhs.contains(lhs)), nil + return types.Boolean(rhs.Contains(lhs)), nil } diff --git a/eval_test.go b/eval_test.go index 6bceca3a..bd5027d3 100644 --- a/eval_test.go +++ b/eval_test.go @@ -6,15 +6,17 @@ import ( "strings" "testing" + "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/parser" ) var errTest = fmt.Errorf("test error") // not a real parser -func strEnt(v string) EntityUID { +func strEnt(v string) types.EntityUID { p := strings.Split(v, "::\"") - return EntityUID{Type: p[0], ID: p[1][:len(p[1])-1]} + return types.EntityUID{Type: p[0], ID: p[1][:len(p[1])-1]} } func TestOrNode(t *testing.T) { @@ -32,10 +34,10 @@ func TestOrNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%v%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() - n := newOrNode(newLiteralEval(Boolean(tt.lhs)), newLiteralEval(Boolean(tt.rhs))) + n := newOrNode(newLiteralEval(types.Boolean(tt.lhs)), newLiteralEval(types.Boolean(tt.rhs))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -43,10 +45,10 @@ func TestOrNode(t *testing.T) { t.Run("TrueXShortCircuit", func(t *testing.T) { t.Parallel() n := newOrNode( - newLiteralEval(Boolean(true)), newLiteralEval(Long(1))) + newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(1))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, true) + testutil.OK(t, err) + types.AssertBoolValue(t, v, true) }) { @@ -55,10 +57,10 @@ func TestOrNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Boolean(true)), errTest}, - {"LhsTypeError", newLiteralEval(Long(1)), newLiteralEval(Boolean(true)), errType}, - {"RhsError", newLiteralEval(Boolean(false)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Boolean(false)), newLiteralEval(Long(1)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Boolean(true)), errTest}, + {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsError", newLiteralEval(types.Boolean(false)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Boolean(false)), newLiteralEval(types.Long(1)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -66,7 +68,7 @@ func TestOrNode(t *testing.T) { t.Parallel() n := newOrNode(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -87,10 +89,10 @@ func TestAndNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%v%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() - n := newAndEval(newLiteralEval(Boolean(tt.lhs)), newLiteralEval(Boolean(tt.rhs))) + n := newAndEval(newLiteralEval(types.Boolean(tt.lhs)), newLiteralEval(types.Boolean(tt.rhs))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -98,10 +100,10 @@ func TestAndNode(t *testing.T) { t.Run("FalseXShortCircuit", func(t *testing.T) { t.Parallel() n := newAndEval( - newLiteralEval(Boolean(false)), newLiteralEval(Long(1))) + newLiteralEval(types.Boolean(false)), newLiteralEval(types.Long(1))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, false) + testutil.OK(t, err) + types.AssertBoolValue(t, v, false) }) { @@ -110,10 +112,10 @@ func TestAndNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Boolean(true)), errTest}, - {"LhsTypeError", newLiteralEval(Long(1)), newLiteralEval(Boolean(true)), errType}, - {"RhsError", newLiteralEval(Boolean(true)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(1)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Boolean(true)), errTest}, + {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsError", newLiteralEval(types.Boolean(true)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(1)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -121,7 +123,7 @@ func TestAndNode(t *testing.T) { t.Parallel() n := newAndEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -140,10 +142,10 @@ func TestNotNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%v", tt.arg), func(t *testing.T) { t.Parallel() - n := newNotEval(newLiteralEval(Boolean(tt.arg))) + n := newNotEval(newLiteralEval(types.Boolean(tt.arg))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -155,7 +157,7 @@ func TestNotNode(t *testing.T) { err error }{ {"Error", newErrorEval(errTest), errTest}, - {"TypeError", newLiteralEval(Long(1)), errType}, + {"TypeError", newLiteralEval(types.Long(1)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -163,7 +165,7 @@ func TestNotNode(t *testing.T) { t.Parallel() n := newNotEval(tt.arg) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -172,7 +174,7 @@ func TestNotNode(t *testing.T) { func TestCheckedAddI64(t *testing.T) { t.Parallel() tests := []struct { - lhs, rhs, result Long + lhs, rhs, result types.Long ok bool }{ {1, 1, 2, true}, @@ -198,8 +200,8 @@ func TestCheckedAddI64(t *testing.T) { t.Run(fmt.Sprintf("%v+%v=%v(%v)", tt.lhs, tt.rhs, tt.result, tt.ok), func(t *testing.T) { t.Parallel() result, ok := checkedAddI64(tt.lhs, tt.rhs) - testutilEquals(t, ok, tt.ok) - testutilEquals(t, result, tt.result) + testutil.Equals(t, ok, tt.ok) + testutil.Equals(t, result, tt.result) }) } } @@ -207,7 +209,7 @@ func TestCheckedAddI64(t *testing.T) { func TestCheckedSubI64(t *testing.T) { t.Parallel() tests := []struct { - lhs, rhs, result Long + lhs, rhs, result types.Long ok bool }{ {1, 1, 0, true}, @@ -233,8 +235,8 @@ func TestCheckedSubI64(t *testing.T) { t.Run(fmt.Sprintf("%v-%v=%v(%v)", tt.lhs, tt.rhs, tt.result, tt.ok), func(t *testing.T) { t.Parallel() result, ok := checkedSubI64(tt.lhs, tt.rhs) - testutilEquals(t, ok, tt.ok) - testutilEquals(t, result, tt.result) + testutil.Equals(t, ok, tt.ok) + testutil.Equals(t, result, tt.result) }) } } @@ -242,7 +244,7 @@ func TestCheckedSubI64(t *testing.T) { func TestCheckedMulI64(t *testing.T) { t.Parallel() tests := []struct { - lhs, rhs, result Long + lhs, rhs, result types.Long ok bool }{ {2, 3, 6, true}, @@ -307,8 +309,8 @@ func TestCheckedMulI64(t *testing.T) { t.Run(fmt.Sprintf("%v*%v=%v(%v)", tt.lhs, tt.rhs, tt.result, tt.ok), func(t *testing.T) { t.Parallel() result, ok := checkedMulI64(tt.lhs, tt.rhs) - testutilEquals(t, ok, tt.ok) - testutilEquals(t, result, tt.result) + testutil.Equals(t, ok, tt.ok) + testutil.Equals(t, result, tt.result) }) } } @@ -316,7 +318,7 @@ func TestCheckedMulI64(t *testing.T) { func TestCheckedNegI64(t *testing.T) { t.Parallel() tests := []struct { - arg, result Long + arg, result types.Long ok bool }{ {2, -2, true}, @@ -331,8 +333,8 @@ func TestCheckedNegI64(t *testing.T) { t.Run(fmt.Sprintf("-%v*=%v(%v)", tt.arg, tt.result, tt.ok), func(t *testing.T) { t.Parallel() result, ok := checkedNegI64(tt.arg) - testutilEquals(t, ok, tt.ok) - testutilEquals(t, result, tt.result) + testutil.Equals(t, ok, tt.ok) + testutil.Equals(t, result, tt.result) }) } } @@ -341,10 +343,10 @@ func TestAddNode(t *testing.T) { t.Parallel() t.Run("Basic", func(t *testing.T) { t.Parallel() - n := newAddEval(newLiteralEval(Long(1)), newLiteralEval(Long(2))) + n := newAddEval(newLiteralEval(types.Long(1)), newLiteralEval(types.Long(2))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertLongValue(t, v, 3) + testutil.OK(t, err) + types.AssertLongValue(t, v, 3) }) tests := []struct { @@ -352,17 +354,17 @@ func TestAddNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, {"PositiveOverflow", - newLiteralEval(Long(9_223_372_036_854_775_807)), - newLiteralEval(Long(1)), + newLiteralEval(types.Long(9_223_372_036_854_775_807)), + newLiteralEval(types.Long(1)), errOverflow}, {"NegativeOverflow", - newLiteralEval(Long(-9_223_372_036_854_775_808)), - newLiteralEval(Long(-1)), + newLiteralEval(types.Long(-9_223_372_036_854_775_808)), + newLiteralEval(types.Long(-1)), errOverflow}, } for _, tt := range tests { @@ -371,7 +373,7 @@ func TestAddNode(t *testing.T) { t.Parallel() n := newAddEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -380,10 +382,10 @@ func TestSubtractNode(t *testing.T) { t.Parallel() t.Run("Basic", func(t *testing.T) { t.Parallel() - n := newSubtractEval(newLiteralEval(Long(1)), newLiteralEval(Long(2))) + n := newSubtractEval(newLiteralEval(types.Long(1)), newLiteralEval(types.Long(2))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertLongValue(t, v, -1) + testutil.OK(t, err) + types.AssertLongValue(t, v, -1) }) tests := []struct { @@ -391,17 +393,17 @@ func TestSubtractNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, {"PositiveOverflow", - newLiteralEval(Long(9_223_372_036_854_775_807)), - newLiteralEval(Long(-1)), + newLiteralEval(types.Long(9_223_372_036_854_775_807)), + newLiteralEval(types.Long(-1)), errOverflow}, {"NegativeOverflow", - newLiteralEval(Long(-9_223_372_036_854_775_808)), - newLiteralEval(Long(1)), + newLiteralEval(types.Long(-9_223_372_036_854_775_808)), + newLiteralEval(types.Long(1)), errOverflow}, } for _, tt := range tests { @@ -410,7 +412,7 @@ func TestSubtractNode(t *testing.T) { t.Parallel() n := newSubtractEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -419,10 +421,10 @@ func TestMultiplyNode(t *testing.T) { t.Parallel() t.Run("Basic", func(t *testing.T) { t.Parallel() - n := newMultiplyEval(newLiteralEval(Long(-3)), newLiteralEval(Long(2))) + n := newMultiplyEval(newLiteralEval(types.Long(-3)), newLiteralEval(types.Long(2))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertLongValue(t, v, -6) + testutil.OK(t, err) + types.AssertLongValue(t, v, -6) }) tests := []struct { @@ -430,17 +432,17 @@ func TestMultiplyNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, {"PositiveOverflow", - newLiteralEval(Long(9_223_372_036_854_775_807)), - newLiteralEval(Long(2)), + newLiteralEval(types.Long(9_223_372_036_854_775_807)), + newLiteralEval(types.Long(2)), errOverflow}, {"NegativeOverflow", - newLiteralEval(Long(-9_223_372_036_854_775_808)), - newLiteralEval(Long(2)), + newLiteralEval(types.Long(-9_223_372_036_854_775_808)), + newLiteralEval(types.Long(2)), errOverflow}, } for _, tt := range tests { @@ -449,7 +451,7 @@ func TestMultiplyNode(t *testing.T) { t.Parallel() n := newMultiplyEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -458,10 +460,10 @@ func TestNegateNode(t *testing.T) { t.Parallel() t.Run("Basic", func(t *testing.T) { t.Parallel() - n := newNegateEval(newLiteralEval(Long(-3))) + n := newNegateEval(newLiteralEval(types.Long(-3))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertLongValue(t, v, 3) + testutil.OK(t, err) + types.AssertLongValue(t, v, 3) }) tests := []struct { @@ -470,8 +472,8 @@ func TestNegateNode(t *testing.T) { err error }{ {"Error", newErrorEval(errTest), errTest}, - {"TypeError", newLiteralEval(Boolean(true)), errType}, - {"Overflow", newLiteralEval(Long(-9_223_372_036_854_775_808)), errOverflow}, + {"TypeError", newLiteralEval(types.Boolean(true)), types.ErrType}, + {"Overflow", newLiteralEval(types.Long(-9_223_372_036_854_775_808)), errOverflow}, } for _, tt := range tests { tt := tt @@ -479,7 +481,7 @@ func TestNegateNode(t *testing.T) { t.Parallel() n := newNegateEval(tt.arg) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -506,10 +508,10 @@ func TestLongLessThanNode(t *testing.T) { t.Run(fmt.Sprintf("%v<%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newLongLessThanEval( - newLiteralEval(Long(tt.lhs)), newLiteralEval(Long(tt.rhs))) + newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -519,10 +521,10 @@ func TestLongLessThanNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -530,7 +532,7 @@ func TestLongLessThanNode(t *testing.T) { t.Parallel() n := newLongLessThanEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -558,10 +560,10 @@ func TestLongLessThanOrEqualNode(t *testing.T) { t.Run(fmt.Sprintf("%v<=%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newLongLessThanOrEqualEval( - newLiteralEval(Long(tt.lhs)), newLiteralEval(Long(tt.rhs))) + newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -571,10 +573,10 @@ func TestLongLessThanOrEqualNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -582,7 +584,7 @@ func TestLongLessThanOrEqualNode(t *testing.T) { t.Parallel() n := newLongLessThanOrEqualEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -610,10 +612,10 @@ func TestLongGreaterThanNode(t *testing.T) { t.Run(fmt.Sprintf("%v>%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newLongGreaterThanEval( - newLiteralEval(Long(tt.lhs)), newLiteralEval(Long(tt.rhs))) + newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -623,10 +625,10 @@ func TestLongGreaterThanNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -634,7 +636,7 @@ func TestLongGreaterThanNode(t *testing.T) { t.Parallel() n := newLongGreaterThanEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -662,10 +664,10 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { t.Run(fmt.Sprintf("%v>=%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newLongGreaterThanOrEqualEval( - newLiteralEval(Long(tt.lhs)), newLiteralEval(Long(tt.rhs))) + newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -675,10 +677,10 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -686,7 +688,7 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { t.Parallel() n := newLongGreaterThanOrEqualEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -713,16 +715,16 @@ func TestDecimalLessThanNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%s<%s", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() - lhsd, err := ParseDecimal(tt.lhs) - testutilOK(t, err) + lhsd, err := types.ParseDecimal(tt.lhs) + testutil.OK(t, err) lhsv := lhsd - rhsd, err := ParseDecimal(tt.rhs) - testutilOK(t, err) + rhsd, err := types.ParseDecimal(tt.rhs) + testutil.OK(t, err) rhsv := rhsd n := newDecimalLessThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -732,10 +734,10 @@ func TestDecimalLessThanNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Decimal(0)), errType}, - {"RhsError", newLiteralEval(Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Decimal(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -743,7 +745,7 @@ func TestDecimalLessThanNode(t *testing.T) { t.Parallel() n := newDecimalLessThanEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -770,16 +772,16 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%s<=%s", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() - lhsd, err := ParseDecimal(tt.lhs) - testutilOK(t, err) + lhsd, err := types.ParseDecimal(tt.lhs) + testutil.OK(t, err) lhsv := lhsd - rhsd, err := ParseDecimal(tt.rhs) - testutilOK(t, err) + rhsd, err := types.ParseDecimal(tt.rhs) + testutil.OK(t, err) rhsv := rhsd n := newDecimalLessThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -789,10 +791,10 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Decimal(0)), errType}, - {"RhsError", newLiteralEval(Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Decimal(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -800,7 +802,7 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { t.Parallel() n := newDecimalLessThanOrEqualEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -827,16 +829,16 @@ func TestDecimalGreaterThanNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%s>%s", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() - lhsd, err := ParseDecimal(tt.lhs) - testutilOK(t, err) + lhsd, err := types.ParseDecimal(tt.lhs) + testutil.OK(t, err) lhsv := lhsd - rhsd, err := ParseDecimal(tt.rhs) - testutilOK(t, err) + rhsd, err := types.ParseDecimal(tt.rhs) + testutil.OK(t, err) rhsv := rhsd n := newDecimalGreaterThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -846,10 +848,10 @@ func TestDecimalGreaterThanNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Decimal(0)), errType}, - {"RhsError", newLiteralEval(Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Decimal(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -857,7 +859,7 @@ func TestDecimalGreaterThanNode(t *testing.T) { t.Parallel() n := newDecimalGreaterThanEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -884,16 +886,16 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%s>=%s", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() - lhsd, err := ParseDecimal(tt.lhs) - testutilOK(t, err) + lhsd, err := types.ParseDecimal(tt.lhs) + testutil.OK(t, err) lhsv := lhsd - rhsd, err := ParseDecimal(tt.rhs) - testutilOK(t, err) + rhsd, err := types.ParseDecimal(tt.rhs) + testutil.OK(t, err) rhsv := rhsd n := newDecimalGreaterThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -903,10 +905,10 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Decimal(0)), errType}, - {"RhsError", newLiteralEval(Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Decimal(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -914,7 +916,7 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { t.Parallel() n := newDecimalGreaterThanOrEqualEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -925,19 +927,19 @@ func TestIfThenElseNode(t *testing.T) { tests := []struct { name string if_, then, else_ evaler - result Value + result types.Value err error }{ - {"Then", newLiteralEval(Boolean(true)), newLiteralEval(Long(42)), - newLiteralEval(Long(-1)), Long(42), + {"Then", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(42)), + newLiteralEval(types.Long(-1)), types.Long(42), nil}, - {"Else", newLiteralEval(Boolean(false)), newLiteralEval(Long(-1)), - newLiteralEval(Long(42)), Long(42), + {"Else", newLiteralEval(types.Boolean(false)), newLiteralEval(types.Long(-1)), + newLiteralEval(types.Long(42)), types.Long(42), nil}, - {"Err", newErrorEval(errTest), newLiteralEval(zeroValue()), newLiteralEval(zeroValue()), zeroValue(), + {"Err", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), newLiteralEval(types.ZeroValue()), types.ZeroValue(), errTest}, - {"ErrType", newLiteralEval(Long(123)), newLiteralEval(zeroValue()), newLiteralEval(zeroValue()), zeroValue(), - errType}, + {"ErrType", newLiteralEval(types.Long(123)), newLiteralEval(types.ZeroValue()), newLiteralEval(types.ZeroValue()), types.ZeroValue(), + types.ErrType}, } for _, tt := range tests { tt := tt @@ -945,8 +947,8 @@ func TestIfThenElseNode(t *testing.T) { t.Parallel() n := newIfThenElseEval(tt.if_, tt.then, tt.else_) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - testutilEquals(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + testutil.Equals(t, v, tt.result) }) } } @@ -956,14 +958,14 @@ func TestEqualNode(t *testing.T) { tests := []struct { name string lhs, rhs evaler - result Value + result types.Value err error }{ - {"equals", newLiteralEval(Long(42)), newLiteralEval(Long(42)), Boolean(true), nil}, - {"notEquals", newLiteralEval(Long(42)), newLiteralEval(Long(1234)), Boolean(false), nil}, - {"leftErr", newErrorEval(errTest), newLiteralEval(zeroValue()), zeroValue(), errTest}, - {"rightErr", newLiteralEval(zeroValue()), newErrorEval(errTest), zeroValue(), errTest}, - {"typesNotEqual", newLiteralEval(Long(1)), newLiteralEval(Boolean(true)), Boolean(false), nil}, + {"equals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(42)), types.Boolean(true), nil}, + {"notEquals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(1234)), types.Boolean(false), nil}, + {"leftErr", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), types.ZeroValue(), errTest}, + {"rightErr", newLiteralEval(types.ZeroValue()), newErrorEval(errTest), types.ZeroValue(), errTest}, + {"typesNotEqual", newLiteralEval(types.Long(1)), newLiteralEval(types.Boolean(true)), types.Boolean(false), nil}, } for _, tt := range tests { tt := tt @@ -971,8 +973,8 @@ func TestEqualNode(t *testing.T) { t.Parallel() n := newEqualEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -982,14 +984,14 @@ func TestNotEqualNode(t *testing.T) { tests := []struct { name string lhs, rhs evaler - result Value + result types.Value err error }{ - {"equals", newLiteralEval(Long(42)), newLiteralEval(Long(42)), Boolean(false), nil}, - {"notEquals", newLiteralEval(Long(42)), newLiteralEval(Long(1234)), Boolean(true), nil}, - {"leftErr", newErrorEval(errTest), newLiteralEval(zeroValue()), zeroValue(), errTest}, - {"rightErr", newLiteralEval(zeroValue()), newErrorEval(errTest), zeroValue(), errTest}, - {"typesNotEqual", newLiteralEval(Long(1)), newLiteralEval(Boolean(true)), Boolean(true), nil}, + {"equals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(42)), types.Boolean(false), nil}, + {"notEquals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(1234)), types.Boolean(true), nil}, + {"leftErr", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), types.ZeroValue(), errTest}, + {"rightErr", newLiteralEval(types.ZeroValue()), newErrorEval(errTest), types.ZeroValue(), errTest}, + {"typesNotEqual", newLiteralEval(types.Long(1)), newLiteralEval(types.Boolean(true)), types.Boolean(true), nil}, } for _, tt := range tests { tt := tt @@ -997,8 +999,8 @@ func TestNotEqualNode(t *testing.T) { t.Parallel() n := newNotEqualEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1008,27 +1010,27 @@ func TestSetLiteralNode(t *testing.T) { tests := []struct { name string elems []evaler - result Value + result types.Value err error }{ - {"empty", []evaler{}, Set{}, nil}, - {"errorNode", []evaler{newErrorEval(errTest)}, zeroValue(), errTest}, + {"empty", []evaler{}, types.Set{}, nil}, + {"errorNode", []evaler{newErrorEval(errTest)}, types.ZeroValue(), errTest}, {"nested", []evaler{ - newLiteralEval(Boolean(true)), - newLiteralEval(Set{ - Boolean(false), - Long(1), + newLiteralEval(types.Boolean(true)), + newLiteralEval(types.Set{ + types.Boolean(false), + types.Long(1), }), - newLiteralEval(Long(10)), + newLiteralEval(types.Long(10)), }, - Set{ - Boolean(true), - Set{ - Boolean(false), - Long(1), + types.Set{ + types.Boolean(true), + types.Set{ + types.Boolean(false), + types.Long(1), }, - Long(10), + types.Long(10), }, nil}, } @@ -1038,8 +1040,8 @@ func TestSetLiteralNode(t *testing.T) { t.Parallel() n := newSetLiteralEval(tt.elems) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1052,9 +1054,9 @@ func TestContainsNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Set{}), newErrorEval(errTest), errTest}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Set{}), newErrorEval(errTest), errTest}, } for _, tt := range tests { tt := tt @@ -1062,28 +1064,28 @@ func TestContainsNode(t *testing.T) { t.Parallel() n := newContainsEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertZeroValue(t, v) + testutil.AssertError(t, err, tt.err) + types.AssertZeroValue(t, v) }) } } { - empty := Set{} - trueAndOne := Set{Boolean(true), Long(1)} - nested := Set{trueAndOne, Boolean(false), Long(2)} + empty := types.Set{} + trueAndOne := types.Set{types.Boolean(true), types.Long(1)} + nested := types.Set{trueAndOne, types.Boolean(false), types.Long(2)} tests := []struct { name string lhs, rhs evaler result bool }{ - {"empty", newLiteralEval(empty), newLiteralEval(Boolean(true)), false}, - {"trueAndOneContainsTrue", newLiteralEval(trueAndOne), newLiteralEval(Boolean(true)), true}, - {"trueAndOneContainsOne", newLiteralEval(trueAndOne), newLiteralEval(Long(1)), true}, - {"trueAndOneDoesNotContainTwo", newLiteralEval(trueAndOne), newLiteralEval(Long(2)), false}, - {"nestedContainsFalse", newLiteralEval(nested), newLiteralEval(Boolean(false)), true}, + {"empty", newLiteralEval(empty), newLiteralEval(types.Boolean(true)), false}, + {"trueAndOneContainsTrue", newLiteralEval(trueAndOne), newLiteralEval(types.Boolean(true)), true}, + {"trueAndOneContainsOne", newLiteralEval(trueAndOne), newLiteralEval(types.Long(1)), true}, + {"trueAndOneDoesNotContainTwo", newLiteralEval(trueAndOne), newLiteralEval(types.Long(2)), false}, + {"nestedContainsFalse", newLiteralEval(nested), newLiteralEval(types.Boolean(false)), true}, {"nestedContainsSet", newLiteralEval(nested), newLiteralEval(trueAndOne), true}, - {"nestedDoesNotContainTrue", newLiteralEval(nested), newLiteralEval(Boolean(true)), false}, + {"nestedDoesNotContainTrue", newLiteralEval(nested), newLiteralEval(types.Boolean(true)), false}, } for _, tt := range tests { tt := tt @@ -1091,8 +1093,8 @@ func TestContainsNode(t *testing.T) { t.Parallel() n := newContainsEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -1106,10 +1108,10 @@ func TestContainsAllNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Set{}), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Set{}), errType}, - {"RhsError", newLiteralEval(Set{}), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Set{}), newLiteralEval(Long(0)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Set{}), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Set{}), types.ErrType}, + {"RhsError", newLiteralEval(types.Set{}), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Set{}), newLiteralEval(types.Long(0)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -1117,16 +1119,16 @@ func TestContainsAllNode(t *testing.T) { t.Parallel() n := newContainsAllEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertZeroValue(t, v) + testutil.AssertError(t, err, tt.err) + types.AssertZeroValue(t, v) }) } } { - empty := Set{} - trueOnly := Set{Boolean(true)} - trueAndOne := Set{Boolean(true), Long(1)} - nested := Set{trueAndOne, Boolean(false), Long(2)} + empty := types.Set{} + trueOnly := types.Set{types.Boolean(true)} + trueAndOne := types.Set{types.Boolean(true), types.Long(1)} + nested := types.Set{trueAndOne, types.Boolean(false), types.Long(2)} tests := []struct { name string @@ -1145,8 +1147,8 @@ func TestContainsAllNode(t *testing.T) { t.Parallel() n := newContainsAllEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -1160,10 +1162,10 @@ func TestContainsAnyNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Set{}), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Set{}), errType}, - {"RhsError", newLiteralEval(Set{}), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Set{}), newLiteralEval(Long(0)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Set{}), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Set{}), types.ErrType}, + {"RhsError", newLiteralEval(types.Set{}), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Set{}), newLiteralEval(types.Long(0)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -1171,17 +1173,17 @@ func TestContainsAnyNode(t *testing.T) { t.Parallel() n := newContainsAnyEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertZeroValue(t, v) + testutil.AssertError(t, err, tt.err) + types.AssertZeroValue(t, v) }) } } { - empty := Set{} - trueOnly := Set{Boolean(true)} - trueAndOne := Set{Boolean(true), Long(1)} - trueAndTwo := Set{Boolean(true), Long(2)} - nested := Set{trueAndOne, Boolean(false), Long(2)} + empty := types.Set{} + trueOnly := types.Set{types.Boolean(true)} + trueAndOne := types.Set{types.Boolean(true), types.Long(1)} + trueAndTwo := types.Set{types.Boolean(true), types.Long(2)} + nested := types.Set{trueAndOne, types.Boolean(false), types.Long(2)} tests := []struct { name string @@ -1202,8 +1204,8 @@ func TestContainsAnyNode(t *testing.T) { t.Parallel() n := newContainsAnyEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -1214,18 +1216,18 @@ func TestRecordLiteralNode(t *testing.T) { tests := []struct { name string elems map[string]evaler - result Value + result types.Value err error }{ - {"empty", map[string]evaler{}, Record{}, nil}, - {"errorNode", map[string]evaler{"foo": newErrorEval(errTest)}, zeroValue(), errTest}, + {"empty", map[string]evaler{}, types.Record{}, nil}, + {"errorNode", map[string]evaler{"foo": newErrorEval(errTest)}, types.ZeroValue(), errTest}, {"ok", map[string]evaler{ - "foo": newLiteralEval(Boolean(true)), - "bar": newLiteralEval(String("baz")), - }, Record{ - "foo": Boolean(true), - "bar": String("baz"), + "foo": newLiteralEval(types.Boolean(true)), + "bar": newLiteralEval(types.String("baz")), + }, types.Record{ + "foo": types.Boolean(true), + "bar": types.String("baz"), }, nil}, } for _, tt := range tests { @@ -1234,8 +1236,8 @@ func TestRecordLiteralNode(t *testing.T) { t.Parallel() n := newRecordLiteralEval(tt.elems) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1246,35 +1248,35 @@ func TestAttributeAccessNode(t *testing.T) { name string object evaler attribute string - result Value + result types.Value err error }{ - {"RecordError", newErrorEval(errTest), "foo", zeroValue(), errTest}, - {"RecordTypeError", newLiteralEval(Boolean(true)), "foo", zeroValue(), errType}, + {"RecordError", newErrorEval(errTest), "foo", types.ZeroValue(), errTest}, + {"RecordTypeError", newLiteralEval(types.Boolean(true)), "foo", types.ZeroValue(), types.ErrType}, {"UnknownAttribute", - newLiteralEval(Record{}), + newLiteralEval(types.Record{}), "foo", - zeroValue(), + types.ZeroValue(), errAttributeAccess}, {"KnownAttribute", - newLiteralEval(Record{"foo": Long(42)}), + newLiteralEval(types.Record{"foo": types.Long(42)}), "foo", - Long(42), + types.Long(42), nil}, {"KnownAttributeOnEntity", - newLiteralEval(EntityUID{"knownType", "knownID"}), + newLiteralEval(types.EntityUID{"knownType", "knownID"}), "knownAttr", - Long(42), + types.Long(42), nil}, {"UnknownEntity", - newLiteralEval(EntityUID{"unknownType", "unknownID"}), + newLiteralEval(types.EntityUID{"unknownType", "unknownID"}), "unknownAttr", - zeroValue(), + types.ZeroValue(), errEntityNotExist}, {"UnspecifiedEntity", - newLiteralEval(EntityUID{"", ""}), + newLiteralEval(types.EntityUID{"", ""}), "knownAttr", - zeroValue(), + types.ZeroValue(), errUnspecifiedEntity}, } for _, tt := range tests { @@ -1285,13 +1287,13 @@ func TestAttributeAccessNode(t *testing.T) { v, err := n.Eval(&evalContext{ Entities: entitiesFromSlice([]Entity{ { - UID: NewEntityUID("knownType", "knownID"), - Attributes: Record{"knownAttr": Long(42)}, + UID: types.NewEntityUID("knownType", "knownID"), + Attributes: types.Record{"knownAttr": types.Long(42)}, }, }), }) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1302,35 +1304,35 @@ func TestHasNode(t *testing.T) { name string record evaler attribute string - result Value + result types.Value err error }{ - {"RecordError", newErrorEval(errTest), "foo", zeroValue(), errTest}, - {"RecordTypeError", newLiteralEval(Boolean(true)), "foo", zeroValue(), errType}, + {"RecordError", newErrorEval(errTest), "foo", types.ZeroValue(), errTest}, + {"RecordTypeError", newLiteralEval(types.Boolean(true)), "foo", types.ZeroValue(), types.ErrType}, {"UnknownAttribute", - newLiteralEval(Record{}), + newLiteralEval(types.Record{}), "foo", - Boolean(false), + types.Boolean(false), nil}, {"KnownAttribute", - newLiteralEval(Record{"foo": Long(42)}), + newLiteralEval(types.Record{"foo": types.Long(42)}), "foo", - Boolean(true), + types.Boolean(true), nil}, {"KnownAttributeOnEntity", - newLiteralEval(EntityUID{"knownType", "knownID"}), + newLiteralEval(types.EntityUID{"knownType", "knownID"}), "knownAttr", - Boolean(true), + types.Boolean(true), nil}, {"UnknownAttributeOnEntity", - newLiteralEval(EntityUID{"knownType", "knownID"}), + newLiteralEval(types.EntityUID{"knownType", "knownID"}), "unknownAttr", - Boolean(false), + types.Boolean(false), nil}, {"UnknownEntity", - newLiteralEval(EntityUID{"unknownType", "unknownID"}), + newLiteralEval(types.EntityUID{"unknownType", "unknownID"}), "unknownAttr", - Boolean(false), + types.Boolean(false), nil}, } for _, tt := range tests { @@ -1341,13 +1343,13 @@ func TestHasNode(t *testing.T) { v, err := n.Eval(&evalContext{ Entities: entitiesFromSlice([]Entity{ { - UID: NewEntityUID("knownType", "knownID"), - Attributes: Record{"knownAttr": Long(42)}, + UID: types.NewEntityUID("knownType", "knownID"), + Attributes: types.Record{"knownAttr": types.Long(42)}, }, }), }) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1358,50 +1360,50 @@ func TestLikeNode(t *testing.T) { name string str evaler pattern string - result Value + result types.Value err error }{ - {"leftError", newErrorEval(errTest), `"foo"`, zeroValue(), errTest}, - {"leftTypeError", newLiteralEval(Boolean(true)), `"foo"`, zeroValue(), errType}, - {"noMatch", newLiteralEval(String("test")), `"zebra"`, Boolean(false), nil}, - {"match", newLiteralEval(String("test")), `"*es*"`, Boolean(true), nil}, - - {"case-1", newLiteralEval(String("eggs")), `"ham*"`, Boolean(false), nil}, - {"case-2", newLiteralEval(String("eggs")), `"*ham"`, Boolean(false), nil}, - {"case-3", newLiteralEval(String("eggs")), `"*ham*"`, Boolean(false), nil}, - {"case-4", newLiteralEval(String("ham and eggs")), `"ham*"`, Boolean(true), nil}, - {"case-5", newLiteralEval(String("ham and eggs")), `"*ham"`, Boolean(false), nil}, - {"case-6", newLiteralEval(String("ham and eggs")), `"*ham*"`, Boolean(true), nil}, - {"case-7", newLiteralEval(String("ham and eggs")), `"*h*a*m*"`, Boolean(true), nil}, - {"case-8", newLiteralEval(String("eggs and ham")), `"ham*"`, Boolean(false), nil}, - {"case-9", newLiteralEval(String("eggs and ham")), `"*ham"`, Boolean(true), nil}, - {"case-10", newLiteralEval(String("eggs, ham, and spinach")), `"ham*"`, Boolean(false), nil}, - {"case-11", newLiteralEval(String("eggs, ham, and spinach")), `"*ham"`, Boolean(false), nil}, - {"case-12", newLiteralEval(String("eggs, ham, and spinach")), `"*ham*"`, Boolean(true), nil}, - {"case-13", newLiteralEval(String("Gotham")), `"ham*"`, Boolean(false), nil}, - {"case-14", newLiteralEval(String("Gotham")), `"*ham"`, Boolean(true), nil}, - {"case-15", newLiteralEval(String("ham")), `"ham"`, Boolean(true), nil}, - {"case-16", newLiteralEval(String("ham")), `"ham*"`, Boolean(true), nil}, - {"case-17", newLiteralEval(String("ham")), `"*ham"`, Boolean(true), nil}, - {"case-18", newLiteralEval(String("ham")), `"*h*a*m*"`, Boolean(true), nil}, - {"case-19", newLiteralEval(String("ham and ham")), `"ham*"`, Boolean(true), nil}, - {"case-20", newLiteralEval(String("ham and ham")), `"*ham"`, Boolean(true), nil}, - {"case-21", newLiteralEval(String("ham")), `"*ham and eggs*"`, Boolean(false), nil}, - {"case-22", newLiteralEval(String("\\afterslash")), `"\\*"`, Boolean(true), nil}, - {"case-23", newLiteralEval(String("string\\with\\backslashes")), `"string\\with\\backslashes"`, Boolean(true), nil}, - {"case-24", newLiteralEval(String("string\\with\\backslashes")), `"string*with*backslashes"`, Boolean(true), nil}, - {"case-25", newLiteralEval(String("string*with*stars")), `"string\*with\*stars"`, Boolean(true), nil}, + {"leftError", newErrorEval(errTest), `"foo"`, types.ZeroValue(), errTest}, + {"leftTypeError", newLiteralEval(types.Boolean(true)), `"foo"`, types.ZeroValue(), types.ErrType}, + {"noMatch", newLiteralEval(types.String("test")), `"zebra"`, types.Boolean(false), nil}, + {"match", newLiteralEval(types.String("test")), `"*es*"`, types.Boolean(true), nil}, + + {"case-1", newLiteralEval(types.String("eggs")), `"ham*"`, types.Boolean(false), nil}, + {"case-2", newLiteralEval(types.String("eggs")), `"*ham"`, types.Boolean(false), nil}, + {"case-3", newLiteralEval(types.String("eggs")), `"*ham*"`, types.Boolean(false), nil}, + {"case-4", newLiteralEval(types.String("ham and eggs")), `"ham*"`, types.Boolean(true), nil}, + {"case-5", newLiteralEval(types.String("ham and eggs")), `"*ham"`, types.Boolean(false), nil}, + {"case-6", newLiteralEval(types.String("ham and eggs")), `"*ham*"`, types.Boolean(true), nil}, + {"case-7", newLiteralEval(types.String("ham and eggs")), `"*h*a*m*"`, types.Boolean(true), nil}, + {"case-8", newLiteralEval(types.String("eggs and ham")), `"ham*"`, types.Boolean(false), nil}, + {"case-9", newLiteralEval(types.String("eggs and ham")), `"*ham"`, types.Boolean(true), nil}, + {"case-10", newLiteralEval(types.String("eggs, ham, and spinach")), `"ham*"`, types.Boolean(false), nil}, + {"case-11", newLiteralEval(types.String("eggs, ham, and spinach")), `"*ham"`, types.Boolean(false), nil}, + {"case-12", newLiteralEval(types.String("eggs, ham, and spinach")), `"*ham*"`, types.Boolean(true), nil}, + {"case-13", newLiteralEval(types.String("Gotham")), `"ham*"`, types.Boolean(false), nil}, + {"case-14", newLiteralEval(types.String("Gotham")), `"*ham"`, types.Boolean(true), nil}, + {"case-15", newLiteralEval(types.String("ham")), `"ham"`, types.Boolean(true), nil}, + {"case-16", newLiteralEval(types.String("ham")), `"ham*"`, types.Boolean(true), nil}, + {"case-17", newLiteralEval(types.String("ham")), `"*ham"`, types.Boolean(true), nil}, + {"case-18", newLiteralEval(types.String("ham")), `"*h*a*m*"`, types.Boolean(true), nil}, + {"case-19", newLiteralEval(types.String("ham and ham")), `"ham*"`, types.Boolean(true), nil}, + {"case-20", newLiteralEval(types.String("ham and ham")), `"*ham"`, types.Boolean(true), nil}, + {"case-21", newLiteralEval(types.String("ham")), `"*ham and eggs*"`, types.Boolean(false), nil}, + {"case-22", newLiteralEval(types.String("\\afterslash")), `"\\*"`, types.Boolean(true), nil}, + {"case-23", newLiteralEval(types.String("string\\with\\backslashes")), `"string\\with\\backslashes"`, types.Boolean(true), nil}, + {"case-24", newLiteralEval(types.String("string\\with\\backslashes")), `"string*with*backslashes"`, types.Boolean(true), nil}, + {"case-25", newLiteralEval(types.String("string*with*stars")), `"string\*with\*stars"`, types.Boolean(true), nil}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() pat, err := parser.NewPattern(tt.pattern) - testutilOK(t, err) + testutil.OK(t, err) n := newLikeEval(tt.str, pat) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1412,24 +1414,24 @@ func TestVariableNode(t *testing.T) { name string context evalContext variable variableName - result Value + result types.Value }{ {"principal", - evalContext{Principal: String("foo")}, + evalContext{Principal: types.String("foo")}, variableNamePrincipal, - String("foo")}, + types.String("foo")}, {"action", - evalContext{Action: String("bar")}, + evalContext{Action: types.String("bar")}, variableNameAction, - String("bar")}, + types.String("bar")}, {"resource", - evalContext{Resource: String("baz")}, + evalContext{Resource: types.String("baz")}, variableNameResource, - String("baz")}, + types.String("baz")}, {"context", - evalContext{Context: String("frob")}, + evalContext{Context: types.String("frob")}, variableNameContext, - String("frob")}, + types.String("frob")}, } for _, tt := range tests { tt := tt @@ -1437,8 +1439,8 @@ func TestVariableNode(t *testing.T) { t.Parallel() n := newVariableEval(tt.variable) v, err := n.Eval(&tt.context) - testutilOK(t, err) - assertValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertValue(t, v, tt.result) }) } } @@ -1530,13 +1532,13 @@ func TestEntityIn(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - rhs := map[EntityUID]struct{}{} + rhs := map[types.EntityUID]struct{}{} for _, v := range tt.rhs { rhs[strEnt(v)] = struct{}{} } entities := Entities{} for k, p := range tt.parents { - var ps []EntityUID + var ps []types.EntityUID for _, pp := range p { ps = append(ps, strEnt(pp)) } @@ -1547,7 +1549,7 @@ func TestEntityIn(t *testing.T) { } } res := entityIn(strEnt(tt.lhs), rhs, entities) - testutilEquals(t, res, tt.result) + testutil.Equals(t, res, tt.result) }) } t.Run("exponentialWithoutCaching", func(t *testing.T) { @@ -1557,24 +1559,24 @@ func TestEntityIn(t *testing.T) { entities := Entities{} for i := 0; i < 100; i++ { - p := []EntityUID{ - NewEntityUID(fmt.Sprint(i+1), "1"), - NewEntityUID(fmt.Sprint(i+1), "2"), + p := []types.EntityUID{ + types.NewEntityUID(fmt.Sprint(i+1), "1"), + types.NewEntityUID(fmt.Sprint(i+1), "2"), } - uid1 := NewEntityUID(fmt.Sprint(i), "1") + uid1 := types.NewEntityUID(fmt.Sprint(i), "1") entities[uid1] = Entity{ UID: uid1, Parents: p, } - uid2 := NewEntityUID(fmt.Sprint(i), "2") + uid2 := types.NewEntityUID(fmt.Sprint(i), "2") entities[uid2] = Entity{ UID: uid2, Parents: p, } } - res := entityIn(NewEntityUID("0", "1"), map[EntityUID]struct{}{NewEntityUID("0", "3"): {}}, entities) - testutilEquals(t, res, false) + res := entityIn(types.NewEntityUID("0", "1"), map[types.EntityUID]struct{}{types.NewEntityUID("0", "3"): {}}, entities) + testutil.Equals(t, res, false) }) } @@ -1583,23 +1585,23 @@ func TestIsNode(t *testing.T) { tests := []struct { name string lhs, rhs evaler - result Value + result types.Value err error }{ - {"happyEq", newLiteralEval(NewEntityUID("X", "z")), newLiteralEval(path("X")), Boolean(true), nil}, - {"happyNeq", newLiteralEval(NewEntityUID("X", "z")), newLiteralEval(path("Y")), Boolean(false), nil}, - {"badLhs", newLiteralEval(Long(42)), newLiteralEval(path("X")), zeroValue(), errType}, - {"badRhs", newLiteralEval(NewEntityUID("X", "z")), newLiteralEval(Long(42)), zeroValue(), errType}, - {"errLhs", newErrorEval(errTest), newLiteralEval(path("X")), zeroValue(), errTest}, - {"errRhs", newLiteralEval(NewEntityUID("X", "z")), newErrorEval(errTest), zeroValue(), errTest}, + {"happyEq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Path("X")), types.Boolean(true), nil}, + {"happyNeq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Path("Y")), types.Boolean(false), nil}, + {"badLhs", newLiteralEval(types.Long(42)), newLiteralEval(types.Path("X")), types.ZeroValue(), types.ErrType}, + {"badRhs", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Long(42)), types.ZeroValue(), types.ErrType}, + {"errLhs", newErrorEval(errTest), newLiteralEval(types.Path("X")), types.ZeroValue(), errTest}, + {"errRhs", newLiteralEval(types.NewEntityUID("X", "z")), newErrorEval(errTest), types.ZeroValue(), errTest}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() got, err := newIsEval(tt.lhs, tt.rhs).Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, got, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, got, tt.result) }) } } @@ -1610,89 +1612,89 @@ func TestInNode(t *testing.T) { name string lhs, rhs evaler parents map[string][]string - result Value + result types.Value err error }{ { "LhsError", newErrorEval(errTest), - newLiteralEval(Set{}), + newLiteralEval(types.Set{}), map[string][]string{}, - zeroValue(), + types.ZeroValue(), errTest, }, { "LhsTypeError", - newLiteralEval(String("foo")), - newLiteralEval(Set{}), + newLiteralEval(types.String("foo")), + newLiteralEval(types.Set{}), map[string][]string{}, - zeroValue(), - errType, + types.ZeroValue(), + types.ErrType, }, { "RhsError", - newLiteralEval(EntityUID{"human", "joe"}), + newLiteralEval(types.EntityUID{"human", "joe"}), newErrorEval(errTest), map[string][]string{}, - zeroValue(), + types.ZeroValue(), errTest, }, { "RhsTypeError1", - newLiteralEval(EntityUID{"human", "joe"}), - newLiteralEval(String("foo")), + newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.String("foo")), map[string][]string{}, - zeroValue(), - errType, + types.ZeroValue(), + types.ErrType, }, { "RhsTypeError2", - newLiteralEval(EntityUID{"human", "joe"}), - newLiteralEval(Set{ - String("foo"), + newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.Set{ + types.String("foo"), }), map[string][]string{}, - zeroValue(), - errType, + types.ZeroValue(), + types.ErrType, }, { "Reflexive1", - newLiteralEval(EntityUID{"human", "joe"}), - newLiteralEval(EntityUID{"human", "joe"}), + newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.EntityUID{"human", "joe"}), map[string][]string{}, - Boolean(true), + types.Boolean(true), nil, }, { "Reflexive2", - newLiteralEval(EntityUID{"human", "joe"}), - newLiteralEval(Set{ - EntityUID{"human", "joe"}, + newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.Set{ + types.EntityUID{"human", "joe"}, }), map[string][]string{}, - Boolean(true), + types.Boolean(true), nil, }, { "BasicTrue", - newLiteralEval(EntityUID{"human", "joe"}), - newLiteralEval(EntityUID{"kingdom", "animal"}), + newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.EntityUID{"kingdom", "animal"}), map[string][]string{ `human::"joe"`: {`species::"human"`}, `species::"human"`: {`kingdom::"animal"`}, }, - Boolean(true), + types.Boolean(true), nil, }, { "BasicFalse", - newLiteralEval(EntityUID{"human", "joe"}), - newLiteralEval(EntityUID{"kingdom", "plant"}), + newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.EntityUID{"kingdom", "plant"}), map[string][]string{ `human::"joe"`: {`species::"human"`}, `species::"human"`: {`kingdom::"animal"`}, }, - Boolean(false), + types.Boolean(false), nil, }, } @@ -1703,7 +1705,7 @@ func TestInNode(t *testing.T) { n := newInEval(tt.lhs, tt.rhs) entities := Entities{} for k, p := range tt.parents { - var ps []EntityUID + var ps []types.EntityUID for _, pp := range p { ps = append(ps, strEnt(pp)) } @@ -1715,8 +1717,8 @@ func TestInNode(t *testing.T) { } evalContext := evalContext{Entities: entities} v, err := n.Eval(&evalContext) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1726,13 +1728,13 @@ func TestDecimalLiteralNode(t *testing.T) { tests := []struct { name string arg evaler - result Value + result types.Value err error }{ - {"Error", newErrorEval(errTest), zeroValue(), errTest}, - {"TypeError", newLiteralEval(Long(1)), zeroValue(), errType}, - {"DecimalError", newLiteralEval(String("frob")), zeroValue(), errDecimal}, - {"Success", newLiteralEval(String("1.0")), Decimal(10000), nil}, + {"Error", newErrorEval(errTest), types.ZeroValue(), errTest}, + {"TypeError", newLiteralEval(types.Long(1)), types.ZeroValue(), types.ErrType}, + {"DecimalError", newLiteralEval(types.String("frob")), types.ZeroValue(), types.ErrDecimal}, + {"Success", newLiteralEval(types.String("1.0")), types.Decimal(10000), nil}, } for _, tt := range tests { tt := tt @@ -1740,26 +1742,26 @@ func TestDecimalLiteralNode(t *testing.T) { t.Parallel() n := newDecimalLiteralEval(tt.arg) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } func TestIPLiteralNode(t *testing.T) { t.Parallel() - ipv6Loopback, err := ParseIPAddr("::1") - testutilOK(t, err) + ipv6Loopback, err := types.ParseIPAddr("::1") + testutil.OK(t, err) tests := []struct { name string arg evaler - result Value + result types.Value err error }{ - {"Error", newErrorEval(errTest), zeroValue(), errTest}, - {"TypeError", newLiteralEval(Long(1)), zeroValue(), errType}, - {"IPError", newLiteralEval(String("not-an-IP-address")), zeroValue(), errIP}, - {"Success", newLiteralEval(String("::1/128")), ipv6Loopback, nil}, + {"Error", newErrorEval(errTest), types.ZeroValue(), errTest}, + {"TypeError", newLiteralEval(types.Long(1)), types.ZeroValue(), types.ErrType}, + {"IPError", newLiteralEval(types.String("not-an-IP-address")), types.ZeroValue(), types.ErrIP}, + {"Success", newLiteralEval(types.String("::1/128")), ipv6Loopback, nil}, } for _, tt := range tests { tt := tt @@ -1767,37 +1769,37 @@ func TestIPLiteralNode(t *testing.T) { t.Parallel() n := newIPLiteralEval(tt.arg) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } func TestIPTestNode(t *testing.T) { t.Parallel() - ipv4Loopback, err := ParseIPAddr("127.0.0.1") - testutilOK(t, err) - ipv6Loopback, err := ParseIPAddr("::1") - testutilOK(t, err) - ipv4Multicast, err := ParseIPAddr("224.0.0.1") - testutilOK(t, err) + ipv4Loopback, err := types.ParseIPAddr("127.0.0.1") + testutil.OK(t, err) + ipv6Loopback, err := types.ParseIPAddr("::1") + testutil.OK(t, err) + ipv4Multicast, err := types.ParseIPAddr("224.0.0.1") + testutil.OK(t, err) tests := []struct { name string lhs evaler rhs ipTestType - result Value + result types.Value err error }{ - {"Error", newErrorEval(errTest), ipTestIPv4, zeroValue(), errTest}, - {"TypeError", newLiteralEval(Long(1)), ipTestIPv4, zeroValue(), errType}, - {"IPv4True", newLiteralEval(ipv4Loopback), ipTestIPv4, Boolean(true), nil}, - {"IPv4False", newLiteralEval(ipv6Loopback), ipTestIPv4, Boolean(false), nil}, - {"IPv6True", newLiteralEval(ipv6Loopback), ipTestIPv6, Boolean(true), nil}, - {"IPv6False", newLiteralEval(ipv4Loopback), ipTestIPv6, Boolean(false), nil}, - {"LoopbackTrue", newLiteralEval(ipv6Loopback), ipTestLoopback, Boolean(true), nil}, - {"LoopbackFalse", newLiteralEval(ipv4Multicast), ipTestLoopback, Boolean(false), nil}, - {"MulticastTrue", newLiteralEval(ipv4Multicast), ipTestMulticast, Boolean(true), nil}, - {"MulticastFalse", newLiteralEval(ipv6Loopback), ipTestMulticast, Boolean(false), nil}, + {"Error", newErrorEval(errTest), ipTestIPv4, types.ZeroValue(), errTest}, + {"TypeError", newLiteralEval(types.Long(1)), ipTestIPv4, types.ZeroValue(), types.ErrType}, + {"IPv4True", newLiteralEval(ipv4Loopback), ipTestIPv4, types.Boolean(true), nil}, + {"IPv4False", newLiteralEval(ipv6Loopback), ipTestIPv4, types.Boolean(false), nil}, + {"IPv6True", newLiteralEval(ipv6Loopback), ipTestIPv6, types.Boolean(true), nil}, + {"IPv6False", newLiteralEval(ipv4Loopback), ipTestIPv6, types.Boolean(false), nil}, + {"LoopbackTrue", newLiteralEval(ipv6Loopback), ipTestLoopback, types.Boolean(true), nil}, + {"LoopbackFalse", newLiteralEval(ipv4Multicast), ipTestLoopback, types.Boolean(false), nil}, + {"MulticastTrue", newLiteralEval(ipv4Multicast), ipTestMulticast, types.Boolean(true), nil}, + {"MulticastFalse", newLiteralEval(ipv6Loopback), ipTestMulticast, types.Boolean(false), nil}, } for _, tt := range tests { tt := tt @@ -1805,37 +1807,37 @@ func TestIPTestNode(t *testing.T) { t.Parallel() n := newIPTestEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } func TestIPIsInRangeNode(t *testing.T) { t.Parallel() - ipv4A, err := ParseIPAddr("1.2.3.4") - testutilOK(t, err) - ipv4B, err := ParseIPAddr("1.2.3.0/24") - testutilOK(t, err) - ipv4C, err := ParseIPAddr("1.2.4.0/24") - testutilOK(t, err) + ipv4A, err := types.ParseIPAddr("1.2.3.4") + testutil.OK(t, err) + ipv4B, err := types.ParseIPAddr("1.2.3.0/24") + testutil.OK(t, err) + ipv4C, err := types.ParseIPAddr("1.2.4.0/24") + testutil.OK(t, err) tests := []struct { name string lhs, rhs evaler - result Value + result types.Value err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(ipv4A), zeroValue(), errTest}, - {"LhsTypeError", newLiteralEval(Long(1)), newLiteralEval(ipv4A), zeroValue(), errType}, - {"RhsError", newLiteralEval(ipv4A), newErrorEval(errTest), zeroValue(), errTest}, - {"RhsTypeError", newLiteralEval(ipv4A), newLiteralEval(Long(1)), zeroValue(), errType}, - {"AA", newLiteralEval(ipv4A), newLiteralEval(ipv4A), Boolean(true), nil}, - {"AB", newLiteralEval(ipv4A), newLiteralEval(ipv4B), Boolean(true), nil}, - {"BA", newLiteralEval(ipv4B), newLiteralEval(ipv4A), Boolean(false), nil}, - {"AC", newLiteralEval(ipv4A), newLiteralEval(ipv4C), Boolean(false), nil}, - {"CA", newLiteralEval(ipv4C), newLiteralEval(ipv4A), Boolean(false), nil}, - {"BC", newLiteralEval(ipv4B), newLiteralEval(ipv4C), Boolean(false), nil}, - {"CB", newLiteralEval(ipv4C), newLiteralEval(ipv4B), Boolean(false), nil}, + {"LhsError", newErrorEval(errTest), newLiteralEval(ipv4A), types.ZeroValue(), errTest}, + {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(ipv4A), types.ZeroValue(), types.ErrType}, + {"RhsError", newLiteralEval(ipv4A), newErrorEval(errTest), types.ZeroValue(), errTest}, + {"RhsTypeError", newLiteralEval(ipv4A), newLiteralEval(types.Long(1)), types.ZeroValue(), types.ErrType}, + {"AA", newLiteralEval(ipv4A), newLiteralEval(ipv4A), types.Boolean(true), nil}, + {"AB", newLiteralEval(ipv4A), newLiteralEval(ipv4B), types.Boolean(true), nil}, + {"BA", newLiteralEval(ipv4B), newLiteralEval(ipv4A), types.Boolean(false), nil}, + {"AC", newLiteralEval(ipv4A), newLiteralEval(ipv4C), types.Boolean(false), nil}, + {"CA", newLiteralEval(ipv4C), newLiteralEval(ipv4A), types.Boolean(false), nil}, + {"BC", newLiteralEval(ipv4B), newLiteralEval(ipv4C), types.Boolean(false), nil}, + {"CB", newLiteralEval(ipv4C), newLiteralEval(ipv4B), types.Boolean(false), nil}, } for _, tt := range tests { tt := tt @@ -1843,8 +1845,8 @@ func TestIPIsInRangeNode(t *testing.T) { t.Parallel() n := newIPIsInRangeEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1853,27 +1855,27 @@ func TestCedarString(t *testing.T) { t.Parallel() tests := []struct { name string - in Value + in types.Value wantString string wantCedar string }{ - {"string", String("hello"), `hello`, `"hello"`}, - {"number", Long(42), `42`, `42`}, - {"bool", Boolean(true), `true`, `true`}, - {"record", Record{"a": Long(42), "b": Long(43)}, `{"a":42,"b":43}`, `{"a":42,"b":43}`}, - {"set", Set{Long(42), Long(43)}, `[42,43]`, `[42,43]`}, - {"singleIP", IPAddr(netip.MustParsePrefix("192.168.0.42/32")), `192.168.0.42`, `ip("192.168.0.42")`}, - {"ipPrefix", IPAddr(netip.MustParsePrefix("192.168.0.42/24")), `192.168.0.42/24`, `ip("192.168.0.42/24")`}, - {"decimal", Decimal(12345678), `1234.5678`, `decimal("1234.5678")`}, + {"string", types.String("hello"), `hello`, `"hello"`}, + {"number", types.Long(42), `42`, `42`}, + {"bool", types.Boolean(true), `true`, `true`}, + {"record", types.Record{"a": types.Long(42), "b": types.Long(43)}, `{"a":42,"b":43}`, `{"a":42,"b":43}`}, + {"set", types.Set{types.Long(42), types.Long(43)}, `[42,43]`, `[42,43]`}, + {"singleIP", types.IPAddr(netip.MustParsePrefix("192.168.0.42/32")), `192.168.0.42`, `ip("192.168.0.42")`}, + {"ipPrefix", types.IPAddr(netip.MustParsePrefix("192.168.0.42/24")), `192.168.0.42/24`, `ip("192.168.0.42/24")`}, + {"decimal", types.Decimal(12345678), `1234.5678`, `decimal("1234.5678")`}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() gotString := tt.in.String() - testutilEquals(t, gotString, tt.wantString) + testutil.Equals(t, gotString, tt.wantString) gotCedar := tt.in.Cedar() - testutilEquals(t, gotCedar, tt.wantCedar) + testutil.Equals(t, gotCedar, tt.wantCedar) }) } } diff --git a/match_test.go b/match_test.go index 39b710a2..783ed174 100644 --- a/match_test.go +++ b/match_test.go @@ -3,6 +3,7 @@ package cedar import ( "testing" + "github.com/cedar-policy/cedar-go/testutil" "github.com/cedar-policy/cedar-go/x/exp/parser" ) @@ -38,9 +39,9 @@ func TestMatch(t *testing.T) { t.Run(tt.pattern+":"+tt.target, func(t *testing.T) { t.Parallel() pat, err := parser.NewPattern(tt.pattern) - testutilOK(t, err) + testutil.OK(t, err) got := match(pat, tt.target) - testutilEquals(t, got, tt.want) + testutil.Equals(t, got, tt.want) }) } } diff --git a/testutil/testutil.go b/testutil/testutil.go new file mode 100644 index 00000000..6d897b67 --- /dev/null +++ b/testutil/testutil.go @@ -0,0 +1,44 @@ +package testutil + +import ( + "errors" + "reflect" + "testing" +) + +func Equals[T any](t testing.TB, a, b T) { + t.Helper() + if reflect.DeepEqual(a, b) { + return + } + t.Fatalf("got %+v want %+v", a, b) +} + +func FatalIf(t testing.TB, c bool, f string, args ...any) { + t.Helper() + if !c { + return + } + t.Fatalf(f, args...) +} + +func OK(t testing.TB, err error) { + t.Helper() + if err == nil { + return + } + t.Fatalf("got %v want nil", err) +} + +func Error(t testing.TB, err error) { + t.Helper() + if err != nil { + return + } + t.Fatalf("got nil want error") +} + +func AssertError(t *testing.T, got, want error) { + t.Helper() + FatalIf(t, !errors.Is(got, want), "err got %v want %v", got, want) +} diff --git a/testutil_test.go b/testutil_test.go deleted file mode 100644 index 2de50a82..00000000 --- a/testutil_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package cedar - -import ( - "errors" - "fmt" - "reflect" - "testing" -) - -func testutilEquals[T any](t testing.TB, a, b T) { - t.Helper() - if reflect.DeepEqual(a, b) { - return - } - t.Fatalf("got %+v want %+v", a, b) -} - -func testutilFatalIf(t testing.TB, c bool, f string, args ...any) { - t.Helper() - if !c { - return - } - t.Fatalf(f, args...) -} - -func testutilOK(t testing.TB, err error) { - t.Helper() - if err == nil { - return - } - t.Fatalf("got %v want nil", err) -} - -func testutilError(t testing.TB, err error) { - t.Helper() - if err != nil { - return - } - t.Fatalf("got nil want error") -} - -func assertError(t *testing.T, got, want error) { - t.Helper() - testutilFatalIf(t, !errors.Is(got, want), "err got %v want %v", got, want) -} - -func assertValue(t *testing.T, got, want Value) { - t.Helper() - testutilFatalIf( - t, - !((got == zeroValue() && want == zeroValue()) || - (got != zeroValue() && want != zeroValue() && got.equal(want))), - "got %v want %v", got, want) -} - -func assertBoolValue(t *testing.T, got Value, want bool) { - t.Helper() - testutilEquals[Value](t, got, Boolean(want)) -} - -func assertLongValue(t *testing.T, got Value, want int64) { - t.Helper() - testutilEquals[Value](t, got, Long(want)) -} - -func assertZeroValue(t *testing.T, got Value) { - t.Helper() - testutilEquals(t, got, zeroValue()) -} - -func assertValueString(t *testing.T, v Value, want string) { - t.Helper() - testutilEquals(t, v.String(), want) -} - -func safeDoErr(f func() error) (err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("%v", r) - } - }() - return f() -} diff --git a/toeval.go b/toeval.go index 1d2f1977..9988178e 100644 --- a/toeval.go +++ b/toeval.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/parser" ) @@ -21,7 +22,7 @@ func toEval(n any) evaler { var res evaler switch v.Type { case parser.MatchAny: - res = newLiteralEval(Boolean(true)) + res = newLiteralEval(types.Boolean(true)) case parser.MatchEquals: res = newEqualEval(newVariableEval(variableNamePrincipal), toEval(v.Entity)) case parser.MatchIn: @@ -38,7 +39,7 @@ func toEval(n any) evaler { var res evaler switch v.Type { case parser.MatchAny: - res = newLiteralEval(Boolean(true)) + res = newLiteralEval(types.Boolean(true)) case parser.MatchEquals: res = newEqualEval(newVariableEval(variableNameAction), toEval(v.Entities[0])) case parser.MatchIn: @@ -56,7 +57,7 @@ func toEval(n any) evaler { var res evaler switch v.Type { case parser.MatchAny: - res = newLiteralEval(Boolean(true)) + res = newLiteralEval(types.Boolean(true)) case parser.MatchEquals: res = newEqualEval(newVariableEval(variableNameResource), toEval(v.Entity)) case parser.MatchIn: @@ -70,9 +71,9 @@ func toEval(n any) evaler { } return res case parser.Entity: - return newLiteralEval(entityValueFromSlice(v.Path)) + return newLiteralEval(types.EntityValueFromSlice(v.Path)) case parser.Path: - return newLiteralEval(pathFromSlice(v.Path)) + return newLiteralEval(types.PathFromSlice(v.Path)) case parser.Condition: var res evaler switch v.Type { @@ -210,11 +211,11 @@ func toEval(n any) evaler { case parser.Literal: switch v.Type { case parser.LiteralBool: - return newLiteralEval(Boolean(v.Bool)) + return newLiteralEval(types.Boolean(v.Bool)) case parser.LiteralInt: - return newLiteralEval(Long(v.Long)) + return newLiteralEval(types.Long(v.Long)) case parser.LiteralString: - return newLiteralEval(String(v.Str)) + return newLiteralEval(types.String(v.Str)) default: panic("missing LiteralType case") } diff --git a/toeval_test.go b/toeval_test.go index 5a64a9d3..037e58ef 100644 --- a/toeval_test.go +++ b/toeval_test.go @@ -1,12 +1,24 @@ package cedar import ( + "fmt" "strings" "testing" + "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/parser" ) +func safeDoErr(f func() error) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + } + }() + return f() +} + func TestToEval(t *testing.T) { t.Parallel() tests := []struct { @@ -18,7 +30,7 @@ func TestToEval(t *testing.T) { {"happy", parser.Entity{ Path: []string{"Action", "test"}, }, - newLiteralEval(entityValueFromSlice([]string{"Action", "test"})), ""}, + newLiteralEval(types.EntityValueFromSlice([]string{"Action", "test"})), ""}, {"missingRelOp", parser.Relation{ Add: parser.Add{ Mults: []parser.Mult{ @@ -148,10 +160,10 @@ func TestToEval(t *testing.T) { out = toEval(tt.in) return nil }) - testutilEquals(t, out, tt.out) - testutilEquals(t, err != nil, tt.panic != "") + testutil.Equals(t, out, tt.out) + testutil.Equals(t, err != nil, tt.panic != "") if tt.panic != "" { - testutilFatalIf(t, !strings.Contains(err.Error(), tt.panic), "panic got %v want %v", err.Error(), tt.panic) + testutil.FatalIf(t, !strings.Contains(err.Error(), tt.panic), "panic got %v want %v", err.Error(), tt.panic) } }) } diff --git a/json.go b/types/json.go similarity index 99% rename from json.go rename to types/json.go index 0ae74298..1df5e0f8 100644 --- a/json.go +++ b/types/json.go @@ -1,4 +1,4 @@ -package cedar +package types import ( "bytes" diff --git a/json_test.go b/types/json_test.go similarity index 70% rename from json_test.go rename to types/json_test.go index 4b62a2ad..c4e2d834 100644 --- a/json_test.go +++ b/types/json_test.go @@ -1,9 +1,11 @@ -package cedar +package types import ( "encoding/json" "fmt" "testing" + + "github.com/cedar-policy/cedar-go/testutil" ) func mustDecimalValue(v string) Decimal { @@ -28,15 +30,15 @@ func TestJSON_Value(t *testing.T) { {"explicitEntity", `{ "__entity": { "type": "User", "id": "alice" } }`, EntityUID{Type: "User", ID: "alice"}, nil}, {"impliedLongEntity", `{ "type": "User::External", "id": "alice" }`, EntityUID{Type: "User::External", ID: "alice"}, nil}, {"explicitLongEntity", `{ "__entity": { "type": "User::External", "id": "alice" } }`, EntityUID{Type: "User::External", ID: "alice"}, nil}, - {"invalidJSON", `!@#$`, zeroValue(), errJSONDecode}, - {"numericOverflow", "12341234123412341234", zeroValue(), errJSONLongOutOfRange}, - {"unsupportedNull", "null", zeroValue(), errJSONUnsupportedType}, + {"invalidJSON", `!@#$`, ZeroValue(), errJSONDecode}, + {"numericOverflow", "12341234123412341234", ZeroValue(), errJSONLongOutOfRange}, + {"unsupportedNull", "null", ZeroValue(), errJSONUnsupportedType}, {"explicitIP", `{ "__extn": { "fn": "ip", "arg": "222.222.222.7" } }`, mustIPValue("222.222.222.7"), nil}, {"explicitSubnet", `{ "__extn": { "fn": "ip", "arg": "192.168.0.0/16" } }`, mustIPValue("192.168.0.0/16"), nil}, {"explicitDecimal", `{ "__extn": { "fn": "decimal", "arg": "33.57" } }`, mustDecimalValue("33.57"), nil}, - {"invalidExtension", `{ "__extn": { "fn": "asdf", "arg": "blah" } }`, zeroValue(), errJSONInvalidExtn}, - {"badIP", `{ "__extn": { "fn": "ip", "arg": "bad" } }`, zeroValue(), errIP}, - {"badDecimal", `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, zeroValue(), errDecimal}, + {"invalidExtension", `{ "__extn": { "fn": "asdf", "arg": "blah" } }`, ZeroValue(), errJSONInvalidExtn}, + {"badIP", `{ "__extn": { "fn": "ip", "arg": "bad" } }`, ZeroValue(), ErrIP}, + {"badDecimal", `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, ZeroValue(), ErrDecimal}, {"set", `[42]`, Set{Long(42)}, nil}, {"record", `{"a":"b"}`, Record{"a": String("b")}, nil}, {"bool", `false`, Boolean(false), nil}, @@ -48,8 +50,8 @@ func TestJSON_Value(t *testing.T) { var got Value ptr := &got err := unmarshalJSON([]byte(tt.in), ptr) - assertError(t, err, tt.err) - assertValue(t, got, tt.want) + testutil.AssertError(t, err, tt.err) + AssertValue(t, got, tt.want) if tt.err != nil { return } @@ -57,12 +59,12 @@ func TestJSON_Value(t *testing.T) { // Now assert that when we Marshal/Unmarshal that value, we still // have what we started with gotJSON, err := (*ptr).ExplicitMarshalJSON() - testutilOK(t, err) + testutil.OK(t, err) var gotRetry Value ptr = &gotRetry err = unmarshalJSON(gotJSON, ptr) - testutilOK(t, err) - testutilEquals(t, gotRetry, tt.want) + testutil.OK(t, err) + testutil.Equals(t, gotRetry, tt.want) }) } } @@ -129,7 +131,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { }, in: `{ "__extn": { "fn": "ip", "arg": "bad" } }`, wantValue: IPAddr{}, - wantErr: errIP, + wantErr: ErrIP, }, { name: "ip/badJSON", @@ -207,7 +209,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { }, in: `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, wantValue: Decimal(0), - wantErr: errDecimal, + wantErr: ErrDecimal, }, { name: "decimal/badJSON", @@ -248,8 +250,8 @@ func TestTypedJSONUnmarshal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() gotValue, gotErr := tt.f([]byte(tt.in)) - testutilEquals(t, gotValue, tt.wantValue) - assertError(t, gotErr, tt.wantErr) + testutil.Equals(t, gotValue, tt.wantValue) + testutil.AssertError(t, gotErr, tt.wantErr) }) } } @@ -286,11 +288,11 @@ func TestJSONMarshal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() outExplicit, err := tt.in.ExplicitMarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(outExplicit), tt.outExplicit) + testutil.OK(t, err) + testutil.Equals(t, string(outExplicit), tt.outExplicit) outImplicit, err := json.Marshal(tt.in) - testutilOK(t, err) - testutilEquals(t, string(outImplicit), tt.outImplicit) + testutil.OK(t, err) + testutil.Equals(t, string(outImplicit), tt.outImplicit) }) } } @@ -299,9 +301,9 @@ type jsonErr struct{} func (j *jsonErr) String() string { return "" } func (j *jsonErr) Cedar() string { return "" } -func (j *jsonErr) equal(Value) bool { return false } +func (j *jsonErr) Equal(Value) bool { return false } func (j *jsonErr) ExplicitMarshalJSON() ([]byte, error) { return nil, fmt.Errorf("jsonErr") } -func (j *jsonErr) typeName() string { return "jsonErr" } +func (j *jsonErr) TypeName() string { return "jsonErr" } func (j *jsonErr) deepClone() Value { return nil } func TestJSONSet(t *testing.T) { @@ -310,13 +312,13 @@ func TestJSONSet(t *testing.T) { t.Parallel() var s Set err := json.Unmarshal([]byte(`[{"__extn":{"fn":"err"}}]`), &s) - testutilError(t, err) + testutil.Error(t, err) }) t.Run("MarshalErr", func(t *testing.T) { t.Parallel() s := Set{&jsonErr{}} _, err := json.Marshal(s) - testutilError(t, err) + testutil.Error(t, err) }) } @@ -326,7 +328,7 @@ func TestJSONRecord(t *testing.T) { t.Parallel() var r Record err := json.Unmarshal([]byte(`{"key":{"__extn":{"fn":"err"}}}`), &r) - testutilError(t, err) + testutil.Error(t, err) }) t.Run("MarshalKeyErrImpossible", func(t *testing.T) { t.Parallel() @@ -335,117 +337,13 @@ func TestJSONRecord(t *testing.T) { r[string(k)] = Boolean(false) v, err := json.Marshal(r) // this demonstrates that invalid keys will still result in json - testutilEquals(t, string(v), `{"\ufffd\u0001":false}`) - testutilOK(t, err) + testutil.Equals(t, string(v), `{"\ufffd\u0001":false}`) + testutil.OK(t, err) }) t.Run("MarshalValueErr", func(t *testing.T) { t.Parallel() r := Record{"key": &jsonErr{}} _, err := json.Marshal(r) - testutilError(t, err) - }) -} - -func TestEntitiesJSON(t *testing.T) { - t.Parallel() - t.Run("Marshal", func(t *testing.T) { - t.Parallel() - e := Entities{} - ent := Entity{ - UID: NewEntityUID("Type", "id"), - Parents: []EntityUID{}, - Attributes: Record{"key": Long(42)}, - } - e[ent.UID] = ent - b, err := e.MarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(b), `[{"uid":{"type":"Type","id":"id"},"attrs":{"key":42}}]`) - }) - - t.Run("Unmarshal", func(t *testing.T) { - t.Parallel() - b := []byte(`[{"uid":{"type":"Type","id":"id"},"parents":[],"attrs":{"key":42}}]`) - var e Entities - err := json.Unmarshal(b, &e) - testutilOK(t, err) - want := Entities{} - ent := Entity{ - UID: NewEntityUID("Type", "id"), - Parents: []EntityUID{}, - Attributes: Record{"key": Long(42)}, - } - want[ent.UID] = ent - testutilEquals(t, e, want) - }) - - t.Run("UnmarshalErr", func(t *testing.T) { - t.Parallel() - var e Entities - err := e.UnmarshalJSON([]byte(`!@#$`)) - testutilError(t, err) - }) -} - -func TestJSONEffect(t *testing.T) { - t.Parallel() - t.Run("MarshalPermit", func(t *testing.T) { - t.Parallel() - e := Permit - b, err := e.MarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(b), `"permit"`) - }) - t.Run("MarshalForbid", func(t *testing.T) { - t.Parallel() - e := Forbid - b, err := e.MarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(b), `"forbid"`) - }) - t.Run("UnmarshalPermit", func(t *testing.T) { - t.Parallel() - var e Effect - err := json.Unmarshal([]byte(`"permit"`), &e) - testutilOK(t, err) - testutilEquals(t, e, Permit) - }) - t.Run("UnmarshalForbid", func(t *testing.T) { - t.Parallel() - var e Effect - err := json.Unmarshal([]byte(`"forbid"`), &e) - testutilOK(t, err) - testutilEquals(t, e, Forbid) - }) -} - -func TestJSONDecision(t *testing.T) { - t.Parallel() - t.Run("MarshalAllow", func(t *testing.T) { - t.Parallel() - d := Allow - b, err := d.MarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(b), `"allow"`) - }) - t.Run("MarshalDeny", func(t *testing.T) { - t.Parallel() - d := Deny - b, err := d.MarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(b), `"deny"`) - }) - t.Run("UnmarshalAllow", func(t *testing.T) { - t.Parallel() - var d Decision - err := json.Unmarshal([]byte(`"allow"`), &d) - testutilOK(t, err) - testutilEquals(t, d, Allow) - }) - t.Run("UnmarshalDeny", func(t *testing.T) { - t.Parallel() - var d Decision - err := json.Unmarshal([]byte(`"deny"`), &d) - testutilOK(t, err) - testutilEquals(t, d, Deny) + testutil.Error(t, err) }) } diff --git a/types/testutil.go b/types/testutil.go new file mode 100644 index 00000000..787f96b9 --- /dev/null +++ b/types/testutil.go @@ -0,0 +1,36 @@ +package types + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/testutil" +) + +func AssertValue(t *testing.T, got, want Value) { + t.Helper() + testutil.FatalIf( + t, + !((got == ZeroValue() && want == ZeroValue()) || + (got != ZeroValue() && want != ZeroValue() && got.Equal(want))), + "got %v want %v", got, want) +} + +func AssertBoolValue(t *testing.T, got Value, want bool) { + t.Helper() + testutil.Equals[Value](t, got, Boolean(want)) +} + +func AssertLongValue(t *testing.T, got Value, want int64) { + t.Helper() + testutil.Equals[Value](t, got, Long(want)) +} + +func AssertZeroValue(t *testing.T, got Value) { + t.Helper() + testutil.Equals(t, got, ZeroValue()) +} + +func AssertValueString(t *testing.T, v Value, want string) { + t.Helper() + testutil.Equals(t, v.String(), want) +} diff --git a/value.go b/types/value.go similarity index 79% rename from value.go rename to types/value.go index 07cef673..27304ddf 100644 --- a/value.go +++ b/types/value.go @@ -1,4 +1,4 @@ -package cedar +package types import ( "bytes" @@ -10,11 +10,14 @@ import ( "strings" "unicode" - "github.com/cedar-policy/cedar-go/x/exp/parser" "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) +var ErrDecimal = fmt.Errorf("error parsing decimal value") +var ErrIP = fmt.Errorf("error parsing ip value") +var ErrType = fmt.Errorf("type error") + type Value interface { // String produces a string representation of the Value. String() string @@ -24,23 +27,23 @@ type Value interface { // applicable) JSON form, which is necessary for marshalling values within // Sets or Records where the type is not defined. ExplicitMarshalJSON() ([]byte, error) - equal(Value) bool - typeName() string + Equal(Value) bool + TypeName() string deepClone() Value } -func zeroValue() Value { +func ZeroValue() Value { return nil } // A Boolean is a value that is either true or false. type Boolean bool -func (a Boolean) equal(bi Value) bool { +func (a Boolean) Equal(bi Value) bool { b, ok := bi.(Boolean) return ok && a == b } -func (v Boolean) typeName() string { return "bool" } +func (v Boolean) TypeName() string { return "bool" } // String produces a string representation of the Boolean, e.g. `true`. func (v Boolean) String() string { return v.Cedar() } @@ -54,17 +57,25 @@ func (v Boolean) Cedar() string { func (v Boolean) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } func (v Boolean) deepClone() Value { return v } +func ValueToBool(v Value) (Boolean, error) { + bv, ok := v.(Boolean) + if !ok { + return false, fmt.Errorf("%w: expected bool, got %v", ErrType, v.TypeName()) + } + return bv, nil +} + // A Long is a whole number without decimals that can range from -9223372036854775808 to 9223372036854775807. type Long int64 -func (a Long) equal(bi Value) bool { +func (a Long) Equal(bi Value) bool { b, ok := bi.(Long) return ok && a == b } // ExplicitMarshalJSON marshals the Long into JSON. func (v Long) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } -func (v Long) typeName() string { return "long" } +func (v Long) TypeName() string { return "long" } // String produces a string representation of the Long, e.g. `42`. func (v Long) String() string { return v.Cedar() } @@ -75,17 +86,25 @@ func (v Long) Cedar() string { } func (v Long) deepClone() Value { return v } +func ValueToLong(v Value) (Long, error) { + lv, ok := v.(Long) + if !ok { + return 0, fmt.Errorf("%w: expected long, got %v", ErrType, v.TypeName()) + } + return lv, nil +} + // A String is a sequence of characters consisting of letters, numbers, or symbols. type String string -func (a String) equal(bi Value) bool { +func (a String) Equal(bi Value) bool { b, ok := bi.(String) return ok && a == b } // ExplicitMarshalJSON marshals the String into JSON. func (v String) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } -func (v String) typeName() string { return "string" } +func (v String) TypeName() string { return "string" } // String produces an unquoted string representation of the String, e.g. `hello`. func (v String) String() string { @@ -94,37 +113,45 @@ func (v String) String() string { // Cedar produces a valid Cedar language representation of the String, e.g. `"hello"`. func (v String) Cedar() string { - return parser.FakeRustQuote(string(v)) + return strconv.Quote(string(v)) } func (v String) deepClone() Value { return v } +func ValueToString(v Value) (String, error) { + sv, ok := v.(String) + if !ok { + return "", fmt.Errorf("%w: expected string, got %v", ErrType, v.TypeName()) + } + return sv, nil +} + // A Set is a collection of elements that can be of the same or different types. type Set []Value -func (s Set) contains(v Value) bool { +func (s Set) Contains(v Value) bool { for _, e := range s { - if e.equal(v) { + if e.Equal(v) { return true } } return false } -// Equals returns true if the sets are equal. -func (s Set) Equals(b Set) bool { return s.equal(b) } +// Equals returns true if the sets are Equal. +func (s Set) Equals(b Set) bool { return s.Equal(b) } -func (as Set) equal(bi Value) bool { +func (as Set) Equal(bi Value) bool { bs, ok := bi.(Set) if !ok { return false } for _, a := range as { - if !bs.contains(a) { + if !bs.Contains(a) { return false } } for _, b := range bs { - if !as.contains(b) { + if !as.Contains(b) { return false } } @@ -170,7 +197,7 @@ func (v Set) MarshalJSON() ([]byte, error) { // explicit JSON form for all the values in the Set. func (v Set) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() } -func (v Set) typeName() string { return "set" } +func (v Set) TypeName() string { return "set" } // String produces a string representation of the Set, e.g. `[1,2,3]`. func (v Set) String() string { return v.Cedar() } @@ -202,21 +229,29 @@ func (v Set) DeepClone() Set { return res } +func ValueToSet(v Value) (Set, error) { + sv, ok := v.(Set) + if !ok { + return nil, fmt.Errorf("%w: expected set, got %v", ErrType, v.TypeName()) + } + return sv, nil +} + // A Record is a collection of attributes. Each attribute consists of a name and // an associated value. Names are simple strings. Values can be of any type. type Record map[string]Value -// Equals returns true if the records are equal. -func (r Record) Equals(b Record) bool { return r.equal(b) } +// Equals returns true if the records are Equal. +func (r Record) Equals(b Record) bool { return r.Equal(b) } -func (a Record) equal(bi Value) bool { +func (a Record) Equal(bi Value) bool { b, ok := bi.(Record) if !ok || len(a) != len(b) { return false } for k, av := range a { bv, ok := b[k] - if !ok || !av.equal(bv) { + if !ok || !av.Equal(bv) { return false } } @@ -264,7 +299,7 @@ func (v Record) MarshalJSON() ([]byte, error) { // ExplicitMarshalJSON marshals the Record into JSON, the marshaller uses the // explicit JSON form for all the values in the Record. func (v Record) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() } -func (r Record) typeName() string { return "record" } +func (r Record) TypeName() string { return "record" } // String produces a string representation of the Record, e.g. `{"a":1,"b":2,"c":3}`. func (r Record) String() string { return r.Cedar() } @@ -282,7 +317,7 @@ func (r Record) Cedar() string { sb.WriteString(",") } first = false - sb.WriteString(parser.FakeRustQuote(k)) + sb.WriteString(strconv.Quote(k)) sb.WriteString(":") sb.WriteString(v.Cedar()) } @@ -303,6 +338,14 @@ func (v Record) DeepClone() Record { return res } +func ValueToRecord(v Value) (Record, error) { + rv, ok := v.(Record) + if !ok { + return nil, fmt.Errorf("%w: expected record got %v", ErrType, v.TypeName()) + } + return rv, nil +} + // An EntityUID is the identifier for a principal, action, or resource. type EntityUID struct { Type string @@ -321,18 +364,18 @@ func (a EntityUID) IsZero() bool { return a.Type == "" && a.ID == "" } -func (a EntityUID) equal(bi Value) bool { +func (a EntityUID) Equal(bi Value) bool { b, ok := bi.(EntityUID) return ok && a == b } -func (v EntityUID) typeName() string { return fmt.Sprintf("(entity of type `%s`)", v.Type) } +func (v EntityUID) TypeName() string { return fmt.Sprintf("(entity of type `%s`)", v.Type) } // String produces a string representation of the EntityUID, e.g. `Type::"id"`. func (v EntityUID) String() string { return v.Cedar() } // Cedar produces a valid Cedar language representation of the EntityUID, e.g. `Type::"id"`. func (v EntityUID) Cedar() string { - return v.Type + "::" + parser.FakeRustQuote(v.ID) + return v.Type + "::" + strconv.Quote(v.ID) } func (v *EntityUID) UnmarshalJSON(b []byte) error { @@ -372,29 +415,45 @@ func (v EntityUID) ExplicitMarshalJSON() ([]byte, error) { } func (v EntityUID) deepClone() Value { return v } -func entityValueFromSlice(v []string) EntityUID { +func ValueToEntity(v Value) (EntityUID, error) { + ev, ok := v.(EntityUID) + if !ok { + return EntityUID{}, fmt.Errorf("%w: expected (entity of type `any_entity_type`), got %v", ErrType, v.TypeName()) + } + return ev, nil +} + +func EntityValueFromSlice(v []string) EntityUID { return EntityUID{ Type: strings.Join(v[:len(v)-1], "::"), ID: v[len(v)-1], } } -// path is the type portion of an EntityUID -type path string +// Path is the type portion of an EntityUID +type Path string -func (a path) equal(bi Value) bool { - b, ok := bi.(path) +func (a Path) Equal(bi Value) bool { + b, ok := bi.(Path) return ok && a == b } -func (v path) typeName() string { return fmt.Sprintf("(path of type `%s`)", v) } +func (v Path) TypeName() string { return fmt.Sprintf("(Path of type `%s`)", v) } + +func (v Path) String() string { return string(v) } +func (v Path) Cedar() string { return string(v) } +func (v Path) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(string(v)) } +func (v Path) deepClone() Value { return v } -func (v path) String() string { return string(v) } -func (v path) Cedar() string { return string(v) } -func (v path) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(string(v)) } -func (v path) deepClone() Value { return v } +func ValueToPath(v Value) (Path, error) { + ev, ok := v.(Path) + if !ok { + return "", fmt.Errorf("%w: expected (Path of type `any_entity_type`), got %v", ErrType, v.TypeName()) + } + return ev, nil +} -func pathFromSlice(v []string) path { - return path(strings.Join(v, "::")) +func PathFromSlice(v []string) Path { + return Path(strings.Join(v, "::")) } // A Decimal is a value with both a whole number part and a decimal part of no @@ -409,7 +468,7 @@ const DecimalPrecision = 10000 func ParseDecimal(s string) (Decimal, error) { // Check for empty string. if len(s) == 0 { - return Decimal(0), fmt.Errorf("%w: string too short", errDecimal) + return Decimal(0), fmt.Errorf("%w: string too short", ErrDecimal) } i := 0 @@ -419,14 +478,14 @@ func ParseDecimal(s string) (Decimal, error) { negative = true i++ if i == len(s) { - return Decimal(0), fmt.Errorf("%w: string too short", errDecimal) + return Decimal(0), fmt.Errorf("%w: string too short", ErrDecimal) } } // Parse the required first digit. c := rune(s[i]) if !unicode.IsDigit(c) { - return Decimal(0), fmt.Errorf("%w: unexpected character %s", errDecimal, strconv.QuoteRune(c)) + return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) } integer := int64(c - '0') i++ @@ -434,18 +493,18 @@ func ParseDecimal(s string) (Decimal, error) { // Parse any other digits, ending with i pointing to '.'. for ; ; i++ { if i == len(s) { - return Decimal(0), fmt.Errorf("%w: string missing decimal point", errDecimal) + return Decimal(0), fmt.Errorf("%w: string missing decimal point", ErrDecimal) } c = rune(s[i]) if c == '.' { break } if !unicode.IsDigit(c) { - return Decimal(0), fmt.Errorf("%w: unexpected character %s", errDecimal, strconv.QuoteRune(c)) + return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) } integer = 10*integer + int64(c-'0') if integer > 922337203685477 { - return Decimal(0), fmt.Errorf("%w: overflow", errDecimal) + return Decimal(0), fmt.Errorf("%w: overflow", ErrDecimal) } } @@ -458,7 +517,7 @@ func ParseDecimal(s string) (Decimal, error) { for ; i < len(s); i++ { c = rune(s[i]) if !unicode.IsDigit(c) { - return Decimal(0), fmt.Errorf("%w: unexpected character %s", errDecimal, strconv.QuoteRune(c)) + return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) } fraction = 10*fraction + int64(c-'0') fractionDigits++ @@ -467,7 +526,7 @@ func ParseDecimal(s string) (Decimal, error) { // Adjust the fraction part based on how many digits we parsed. switch fractionDigits { case 0: - return Decimal(0), fmt.Errorf("%w: missing digits after decimal point", errDecimal) + return Decimal(0), fmt.Errorf("%w: missing digits after decimal point", ErrDecimal) case 1: fraction *= 1000 case 2: @@ -476,12 +535,12 @@ func ParseDecimal(s string) (Decimal, error) { fraction *= 10 case 4: default: - return Decimal(0), fmt.Errorf("%w: too many digits after decimal point", errDecimal) + return Decimal(0), fmt.Errorf("%w: too many digits after decimal point", ErrDecimal) } // Check for overflow before we put the number together. if integer >= 922337203685477 && (fraction > 5808 || (!negative && fraction == 5808)) { - return Decimal(0), fmt.Errorf("%w: overflow", errDecimal) + return Decimal(0), fmt.Errorf("%w: overflow", ErrDecimal) } // Put the number together. @@ -496,12 +555,12 @@ func ParseDecimal(s string) (Decimal, error) { } } -func (a Decimal) equal(bi Value) bool { +func (a Decimal) Equal(bi Value) bool { b, ok := bi.(Decimal) return ok && a == b } -func (v Decimal) typeName() string { return "decimal" } +func (v Decimal) TypeName() string { return "decimal" } // Cedar produces a valid Cedar language representation of the Decimal, e.g. `decimal("12.34")`. func (v Decimal) Cedar() string { return `decimal("` + v.String() + `")` } @@ -573,6 +632,14 @@ func (v Decimal) ExplicitMarshalJSON() ([]byte, error) { } func (v Decimal) deepClone() Value { return v } +func ValueToDecimal(v Value) (Decimal, error) { + d, ok := v.(Decimal) + if !ok { + return 0, fmt.Errorf("%w: expected decimal, got %v", ErrType, v.TypeName()) + } + return d, nil +} + // An IPAddr is value that represents an IP address. It can be either IPv4 or IPv6. // The value can represent an individual address or a range of addresses. type IPAddr netip.Prefix @@ -581,22 +648,22 @@ type IPAddr netip.Prefix func ParseIPAddr(s string) (IPAddr, error) { // We disallow IPv4-mapped IPv6 addresses in dotted notation because Cedar does. if strings.Count(s, ":") >= 2 && strings.Count(s, ".") >= 2 { - return IPAddr{}, fmt.Errorf("%w: cannot parse IPv4 addresses embedded in IPv6 addresses", errIP) + return IPAddr{}, fmt.Errorf("%w: cannot parse IPv4 addresses embedded in IPv6 addresses", ErrIP) } else if net, err := netip.ParsePrefix(s); err == nil { return IPAddr(net), nil } else if addr, err := netip.ParseAddr(s); err == nil { return IPAddr(netip.PrefixFrom(addr, addr.BitLen())), nil } else { - return IPAddr{}, fmt.Errorf("%w: error parsing IP address %s", errIP, s) + return IPAddr{}, fmt.Errorf("%w: error parsing IP address %s", ErrIP, s) } } -func (a IPAddr) equal(bi Value) bool { +func (a IPAddr) Equal(bi Value) bool { b, ok := bi.(IPAddr) return ok && a == b } -func (v IPAddr) typeName() string { return "IP" } +func (v IPAddr) TypeName() string { return "IP" } // Cedar produces a valid Cedar language representation of the IPAddr, e.g. `ip("127.0.0.1")`. func (v IPAddr) Cedar() string { return `ip("` + v.String() + `")` } @@ -613,15 +680,15 @@ func (v IPAddr) Prefix() netip.Prefix { return netip.Prefix(v) } -func (v IPAddr) isIPv4() bool { +func (v IPAddr) IsIPv4() bool { return v.Addr().Is4() } -func (v IPAddr) isIPv6() bool { +func (v IPAddr) IsIPv6() bool { return v.Addr().Is6() } -func (v IPAddr) isLoopback() bool { +func (v IPAddr) IsLoopback() bool { // This comment is in the Cedar Rust implementation: // // Loopback addresses are "127.0.0.0/8" for IpV4 and "::1" for IpV6 @@ -640,7 +707,7 @@ func (v IPAddr) Addr() netip.Addr { return netip.Prefix(v).Addr() } -func (v IPAddr) isMulticast() bool { +func (v IPAddr) IsMulticast() bool { // This comment is in the Cedar Rust implementation: // // Multicast addresses are "224.0.0.0/4" for IpV4 and "ff00::/8" for @@ -654,7 +721,7 @@ func (v IPAddr) isMulticast() bool { // range `ip2/prefix2`, then `ip1` is in `ip2/prefix2` and `prefix1 >= // prefix2` var min_prefix_len int - if v.isIPv4() { + if v.IsIPv4() { min_prefix_len = 4 } else { min_prefix_len = 8 @@ -662,7 +729,7 @@ func (v IPAddr) isMulticast() bool { return v.Addr().IsMulticast() && v.Prefix().Bits() >= min_prefix_len } -func (c IPAddr) contains(o IPAddr) bool { +func (c IPAddr) Contains(o IPAddr) bool { return c.Prefix().Contains(o.Addr()) && c.Prefix().Bits() <= o.Prefix().Bits() } @@ -721,3 +788,11 @@ func (v IPAddr) ExplicitMarshalJSON() ([]byte, error) { // in this case, netip.Prefix does contain a pointer, but // the interface given is immutable, so it is safe to return func (v IPAddr) deepClone() Value { return v } + +func ValueToIP(v Value) (IPAddr, error) { + i, ok := v.(IPAddr) + if !ok { + return IPAddr{}, fmt.Errorf("%w: expected ipaddr, got %v", ErrType, v.TypeName()) + } + return i, nil +} diff --git a/value_test.go b/types/value_test.go similarity index 62% rename from value_test.go rename to types/value_test.go index 08eadf2b..9381e78e 100644 --- a/value_test.go +++ b/types/value_test.go @@ -1,48 +1,50 @@ -package cedar +package types import ( "fmt" "testing" + + "github.com/cedar-policy/cedar-go/testutil" ) func TestBool(t *testing.T) { t.Parallel() t.Run("roundTrip", func(t *testing.T) { t.Parallel() - v, err := valueToBool(Boolean(true)) - testutilOK(t, err) - testutilEquals(t, v, true) + v, err := ValueToBool(Boolean(true)) + testutil.OK(t, err) + testutil.Equals(t, v, true) }) t.Run("toBoolOnNonBool", func(t *testing.T) { t.Parallel() - v, err := valueToBool(Long(0)) - assertError(t, err, errType) - testutilEquals(t, v, false) + v, err := ValueToBool(Long(0)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, false) }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() t1 := Boolean(true) t2 := Boolean(true) f := Boolean(false) zero := Long(0) - testutilFatalIf(t, !t1.equal(t1), "%v not equal to %v", t1, t1) - testutilFatalIf(t, !t1.equal(t2), "%v not equal to %v", t1, t2) - testutilFatalIf(t, t1.equal(f), "%v equal to %v", t1, f) - testutilFatalIf(t, f.equal(t1), "%v equal to %v", f, t1) - testutilFatalIf(t, f.equal(zero), "%v equal to %v", f, zero) + testutil.FatalIf(t, !t1.Equal(t1), "%v not Equal to %v", t1, t1) + testutil.FatalIf(t, !t1.Equal(t2), "%v not Equal to %v", t1, t2) + testutil.FatalIf(t, t1.Equal(f), "%v Equal to %v", t1, f) + testutil.FatalIf(t, f.Equal(t1), "%v Equal to %v", f, t1) + testutil.FatalIf(t, f.Equal(zero), "%v Equal to %v", f, zero) }) t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, Boolean(true), "true") + AssertValueString(t, Boolean(true), "true") }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := Boolean(true).typeName() - testutilEquals(t, tn, "bool") + tn := Boolean(true).TypeName() + testutil.Equals(t, tn, "bool") }) } @@ -50,40 +52,40 @@ func TestLong(t *testing.T) { t.Parallel() t.Run("roundTrip", func(t *testing.T) { t.Parallel() - v, err := valueToLong(Long(42)) - testutilOK(t, err) - testutilEquals(t, v, 42) + v, err := ValueToLong(Long(42)) + testutil.OK(t, err) + testutil.Equals(t, v, 42) }) t.Run("toLongOnNonLong", func(t *testing.T) { t.Parallel() - v, err := valueToLong(Boolean(true)) - assertError(t, err, errType) - testutilEquals(t, v, 0) + v, err := ValueToLong(Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, 0) }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() one := Long(1) one2 := Long(1) zero := Long(0) f := Boolean(false) - testutilFatalIf(t, !one.equal(one), "%v not equal to %v", one, one) - testutilFatalIf(t, !one.equal(one2), "%v not equal to %v", one, one2) - testutilFatalIf(t, one.equal(zero), "%v equal to %v", one, zero) - testutilFatalIf(t, zero.equal(one), "%v equal to %v", zero, one) - testutilFatalIf(t, zero.equal(f), "%v equal to %v", zero, f) + testutil.FatalIf(t, !one.Equal(one), "%v not Equal to %v", one, one) + testutil.FatalIf(t, !one.Equal(one2), "%v not Equal to %v", one, one2) + testutil.FatalIf(t, one.Equal(zero), "%v Equal to %v", one, zero) + testutil.FatalIf(t, zero.Equal(one), "%v Equal to %v", zero, one) + testutil.FatalIf(t, zero.Equal(f), "%v Equal to %v", zero, f) }) t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, Long(1), "1") + AssertValueString(t, Long(1), "1") }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := Long(1).typeName() - testutilEquals(t, tn, "long") + tn := Long(1).TypeName() + testutil.Equals(t, tn, "long") }) } @@ -91,38 +93,38 @@ func TestString(t *testing.T) { t.Parallel() t.Run("roundTrip", func(t *testing.T) { t.Parallel() - v, err := valueToString(String("hello")) - testutilOK(t, err) - testutilEquals(t, v, "hello") + v, err := ValueToString(String("hello")) + testutil.OK(t, err) + testutil.Equals(t, v, "hello") }) t.Run("toStringOnNonString", func(t *testing.T) { t.Parallel() - v, err := valueToString(Boolean(true)) - assertError(t, err, errType) - testutilEquals(t, v, "") + v, err := ValueToString(Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, "") }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() hello := String("hello") hello2 := String("hello") goodbye := String("goodbye") - testutilFatalIf(t, !hello.equal(hello), "%v not equal to %v", hello, hello) - testutilFatalIf(t, !hello.equal(hello2), "%v not equal to %v", hello, hello2) - testutilFatalIf(t, hello.equal(goodbye), "%v equal to %v", hello, goodbye) + testutil.FatalIf(t, !hello.Equal(hello), "%v not Equal to %v", hello, hello) + testutil.FatalIf(t, !hello.Equal(hello2), "%v not Equal to %v", hello, hello2) + testutil.FatalIf(t, hello.Equal(goodbye), "%v Equal to %v", hello, goodbye) }) t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, String("hello"), `hello`) - assertValueString(t, String("hello\ngoodbye"), "hello\ngoodbye") + AssertValueString(t, String("hello"), `hello`) + AssertValueString(t, String("hello\ngoodbye"), "hello\ngoodbye") }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := String("hello").typeName() - testutilEquals(t, tn, "string") + tn := String("hello").TypeName() + testutil.Equals(t, tn, "string") }) } @@ -131,20 +133,20 @@ func TestSet(t *testing.T) { t.Run("roundTrip", func(t *testing.T) { t.Parallel() v := Set{Boolean(true), Long(1)} - slice, err := valueToSet(v) - testutilOK(t, err) + slice, err := ValueToSet(v) + testutil.OK(t, err) v2 := slice - testutilFatalIf(t, !v.equal(v2), "got %v want %v", v, v2) + testutil.FatalIf(t, !v.Equal(v2), "got %v want %v", v, v2) }) t.Run("ToSetOnNonSet", func(t *testing.T) { t.Parallel() - v, err := valueToSet(Boolean(true)) - assertError(t, err, errType) - testutilEquals(t, v, nil) + v, err := ValueToSet(Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, nil) }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() empty := Set{} empty2 := Set{} @@ -162,34 +164,34 @@ func TestSet(t *testing.T) { Long(3), Long(2), Long(2), Long(1), } - testutilFatalIf(t, !empty.Equals(empty), "%v not equal to %v", empty, empty) - testutilFatalIf(t, !empty.Equals(empty2), "%v not equal to %v", empty, empty2) - testutilFatalIf(t, !oneTrue.Equals(oneTrue), "%v not equal to %v", oneTrue, oneTrue) - testutilFatalIf(t, !oneTrue.Equals(oneTrue2), "%v not equal to %v", oneTrue, oneTrue2) - testutilFatalIf(t, !nestedOnce.Equals(nestedOnce), "%v not equal to %v", nestedOnce, nestedOnce) - testutilFatalIf(t, !nestedOnce.Equals(nestedOnce2), "%v not equal to %v", nestedOnce, nestedOnce2) - testutilFatalIf(t, !nestedTwice.Equals(nestedTwice), "%v not equal to %v", nestedTwice, nestedTwice) - testutilFatalIf(t, !nestedTwice.Equals(nestedTwice2), "%v not equal to %v", nestedTwice, nestedTwice2) - testutilFatalIf(t, !oneTwoThree.Equals(threeTwoTwoOne), "%v not equal to %v", oneTwoThree, threeTwoTwoOne) + testutil.FatalIf(t, !empty.Equals(empty), "%v not Equal to %v", empty, empty) + testutil.FatalIf(t, !empty.Equals(empty2), "%v not Equal to %v", empty, empty2) + testutil.FatalIf(t, !oneTrue.Equals(oneTrue), "%v not Equal to %v", oneTrue, oneTrue) + testutil.FatalIf(t, !oneTrue.Equals(oneTrue2), "%v not Equal to %v", oneTrue, oneTrue2) + testutil.FatalIf(t, !nestedOnce.Equals(nestedOnce), "%v not Equal to %v", nestedOnce, nestedOnce) + testutil.FatalIf(t, !nestedOnce.Equals(nestedOnce2), "%v not Equal to %v", nestedOnce, nestedOnce2) + testutil.FatalIf(t, !nestedTwice.Equals(nestedTwice), "%v not Equal to %v", nestedTwice, nestedTwice) + testutil.FatalIf(t, !nestedTwice.Equals(nestedTwice2), "%v not Equal to %v", nestedTwice, nestedTwice2) + testutil.FatalIf(t, !oneTwoThree.Equals(threeTwoTwoOne), "%v not Equal to %v", oneTwoThree, threeTwoTwoOne) - testutilFatalIf(t, empty.Equals(oneFalse), "%v equal to %v", empty, oneFalse) - testutilFatalIf(t, oneTrue.Equals(oneFalse), "%v equal to %v", oneTrue, oneFalse) - testutilFatalIf(t, nestedOnce.Equals(nestedTwice), "%v equal to %v", nestedOnce, nestedTwice) + testutil.FatalIf(t, empty.Equals(oneFalse), "%v Equal to %v", empty, oneFalse) + testutil.FatalIf(t, oneTrue.Equals(oneFalse), "%v Equal to %v", oneTrue, oneFalse) + testutil.FatalIf(t, nestedOnce.Equals(nestedTwice), "%v Equal to %v", nestedOnce, nestedTwice) }) t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, Set{}, "[]") - assertValueString( + AssertValueString(t, Set{}, "[]") + AssertValueString( t, Set{Boolean(true), Long(1)}, "[true,1]") }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := Set{}.typeName() - testutilEquals(t, tn, "set") + tn := Set{}.TypeName() + testutil.Equals(t, tn, "set") }) } @@ -201,20 +203,20 @@ func TestRecord(t *testing.T) { "foo": Boolean(true), "bar": Long(1), } - map_, err := valueToRecord(v) - testutilOK(t, err) + map_, err := ValueToRecord(v) + testutil.OK(t, err) v2 := map_ - testutilFatalIf(t, !v.equal(v2), "got %v want %v", v, v2) + testutil.FatalIf(t, !v.Equal(v2), "got %v want %v", v, v2) }) t.Run("toRecordOnNonRecord", func(t *testing.T) { t.Parallel() - v, err := valueToRecord(String("hello")) - assertError(t, err, errType) - testutilEquals(t, v, nil) + v, err := ValueToRecord(String("hello")) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, nil) }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() empty := Record{} empty2 := Record{} @@ -245,28 +247,28 @@ func TestRecord(t *testing.T) { "nest": twoElems, } - testutilFatalIf(t, !empty.Equals(empty), "%v not equal to %v", empty, empty) - testutilFatalIf(t, !empty.Equals(empty2), "%v not equal to %v", empty, empty2) + testutil.FatalIf(t, !empty.Equals(empty), "%v not Equal to %v", empty, empty) + testutil.FatalIf(t, !empty.Equals(empty2), "%v not Equal to %v", empty, empty2) - testutilFatalIf(t, !twoElems.Equals(twoElems), "%v not equal to %v", twoElems, twoElems) - testutilFatalIf(t, !twoElems.Equals(twoElems2), "%v not equal to %v", twoElems, twoElems2) + testutil.FatalIf(t, !twoElems.Equals(twoElems), "%v not Equal to %v", twoElems, twoElems) + testutil.FatalIf(t, !twoElems.Equals(twoElems2), "%v not Equal to %v", twoElems, twoElems2) - testutilFatalIf(t, !nested.Equals(nested), "%v not equal to %v", nested, nested) - testutilFatalIf(t, !nested.Equals(nested2), "%v not equal to %v", nested, nested2) + testutil.FatalIf(t, !nested.Equals(nested), "%v not Equal to %v", nested, nested) + testutil.FatalIf(t, !nested.Equals(nested2), "%v not Equal to %v", nested, nested2) - testutilFatalIf(t, nested.Equals(twoElems), "%v equal to %v", nested, twoElems) - testutilFatalIf(t, twoElems.Equals(differentValues), "%v equal to %v", twoElems, differentValues) - testutilFatalIf(t, twoElems.Equals(differentKeys), "%v equal to %v", twoElems, differentKeys) + testutil.FatalIf(t, nested.Equals(twoElems), "%v Equal to %v", nested, twoElems) + testutil.FatalIf(t, twoElems.Equals(differentValues), "%v Equal to %v", twoElems, differentValues) + testutil.FatalIf(t, twoElems.Equals(differentKeys), "%v Equal to %v", twoElems, differentKeys) }) t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, Record{}, "{}") - assertValueString( + AssertValueString(t, Record{}, "{}") + AssertValueString( t, Record{"foo": Boolean(true)}, `{"foo":true}`) - assertValueString( + AssertValueString( t, Record{ "foo": Boolean(true), @@ -275,10 +277,10 @@ func TestRecord(t *testing.T) { `{"bar":"blah","foo":true}`) }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := Record{}.typeName() - testutilEquals(t, tn, "record") + tn := Record{}.TypeName() + testutil.Equals(t, tn, "record") }) } @@ -287,37 +289,37 @@ func TestEntity(t *testing.T) { t.Run("roundTrip", func(t *testing.T) { t.Parallel() want := EntityUID{Type: "User", ID: "bananas"} - v, err := valueToEntity(want) - testutilOK(t, err) - testutilEquals(t, v, want) + v, err := ValueToEntity(want) + testutil.OK(t, err) + testutil.Equals(t, v, want) }) t.Run("ToEntityOnNonEntity", func(t *testing.T) { t.Parallel() - v, err := valueToEntity(String("hello")) - assertError(t, err, errType) - testutilEquals(t, v, EntityUID{}) + v, err := ValueToEntity(String("hello")) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, EntityUID{}) }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() twoElems := EntityUID{"type", "id"} twoElems2 := EntityUID{"type", "id"} differentValues := EntityUID{"asdf", "vfds"} - testutilFatalIf(t, !twoElems.equal(twoElems), "%v not equal to %v", twoElems, twoElems) - testutilFatalIf(t, !twoElems.equal(twoElems2), "%v not equal to %v", twoElems, twoElems2) - testutilFatalIf(t, twoElems.equal(differentValues), "%v equal to %v", twoElems, differentValues) + testutil.FatalIf(t, !twoElems.Equal(twoElems), "%v not Equal to %v", twoElems, twoElems) + testutil.FatalIf(t, !twoElems.Equal(twoElems2), "%v not Equal to %v", twoElems, twoElems2) + testutil.FatalIf(t, twoElems.Equal(differentValues), "%v Equal to %v", twoElems, differentValues) }) t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, EntityUID{Type: "type", ID: "id"}, `type::"id"`) - assertValueString(t, EntityUID{Type: "namespace::type", ID: "id"}, `namespace::type::"id"`) + AssertValueString(t, EntityUID{Type: "type", ID: "id"}, `type::"id"`) + AssertValueString(t, EntityUID{Type: "namespace::type", ID: "id"}, `namespace::type::"id"`) }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := EntityUID{"T", "id"}.typeName() - testutilEquals(t, tn, "(entity of type `T`)") + tn := EntityUID{"T", "id"}.TypeName() + testutil.Equals(t, tn, "(entity of type `T`)") }) } @@ -378,8 +380,8 @@ func TestDecimal(t *testing.T) { t.Run(fmt.Sprintf("%s->%s", tt.in, tt.out), func(t *testing.T) { t.Parallel() d, err := ParseDecimal(tt.in) - testutilOK(t, err) - testutilEquals(t, d.String(), tt.out) + testutil.OK(t, err) + testutil.Equals(t, d.String(), tt.out) }) } } @@ -414,8 +416,8 @@ func TestDecimal(t *testing.T) { t.Run(fmt.Sprintf("%s->%s", tt.in, tt.errStr), func(t *testing.T) { t.Parallel() _, err := ParseDecimal(tt.in) - assertError(t, err, errDecimal) - testutilEquals(t, err.Error(), tt.errStr) + testutil.AssertError(t, err, ErrDecimal) + testutil.Equals(t, err.Error(), tt.errStr) }) } } @@ -423,36 +425,36 @@ func TestDecimal(t *testing.T) { t.Run("roundTrip", func(t *testing.T) { t.Parallel() dv, err := ParseDecimal("1.20") - testutilOK(t, err) - v, err := valueToDecimal(dv) - testutilOK(t, err) - testutilFatalIf(t, !v.equal(dv), "got %v want %v", v, dv) + testutil.OK(t, err) + v, err := ValueToDecimal(dv) + testutil.OK(t, err) + testutil.FatalIf(t, !v.Equal(dv), "got %v want %v", v, dv) }) t.Run("toDecimalOnNonDecimal", func(t *testing.T) { t.Parallel() - v, err := valueToDecimal(Boolean(true)) - assertError(t, err, errType) - testutilEquals(t, v, 0) + v, err := ValueToDecimal(Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, 0) }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() one := Decimal(10000) one2 := Decimal(10000) zero := Decimal(0) f := Boolean(false) - testutilFatalIf(t, !one.equal(one), "%v not equal to %v", one, one) - testutilFatalIf(t, !one.equal(one2), "%v not equal to %v", one, one2) - testutilFatalIf(t, one.equal(zero), "%v equal to %v", one, zero) - testutilFatalIf(t, zero.equal(one), "%v equal to %v", zero, one) - testutilFatalIf(t, zero.equal(f), "%v equal to %v", zero, f) + testutil.FatalIf(t, !one.Equal(one), "%v not Equal to %v", one, one) + testutil.FatalIf(t, !one.Equal(one2), "%v not Equal to %v", one, one2) + testutil.FatalIf(t, one.Equal(zero), "%v Equal to %v", one, zero) + testutil.FatalIf(t, zero.Equal(one), "%v Equal to %v", zero, one) + testutil.FatalIf(t, zero.Equal(f), "%v Equal to %v", zero, f) }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := Decimal(0).typeName() - testutilEquals(t, tn, "decimal") + tn := Decimal(0).TypeName() + testutil.Equals(t, tn, "decimal") }) } @@ -500,10 +502,10 @@ func TestIP(t *testing.T) { t.Parallel() i, err := ParseIPAddr(tt.in) if tt.parses { - testutilOK(t, err) - testutilEquals(t, i.String(), tt.out) + testutil.OK(t, err) + testutil.Equals(t, i.String(), tt.out) } else { - testutilError(t, err) + testutil.Error(t, err) } }) } @@ -511,9 +513,9 @@ func TestIP(t *testing.T) { t.Run("toIPOnNonIP", func(t *testing.T) { t.Parallel() - v, err := valueToIP(Boolean(true)) - assertError(t, err, errType) - testutilEquals(t, v, IPAddr{}) + v, err := ValueToIP(Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, IPAddr{}) }) t.Run("Equal", func(t *testing.T) { @@ -547,25 +549,25 @@ func TestIP(t *testing.T) { } for _, tt := range tests { tt := tt - t.Run(fmt.Sprintf("ip(%v).equal(ip(%v))", tt.lhs, tt.rhs), func(t *testing.T) { + t.Run(fmt.Sprintf("ip(%v).Equal(ip(%v))", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() lhs, err := ParseIPAddr(tt.lhs) - testutilOK(t, err) + testutil.OK(t, err) rhs, err := ParseIPAddr(tt.rhs) - testutilOK(t, err) - equal := lhs.equal(rhs) + testutil.OK(t, err) + equal := lhs.Equal(rhs) if equal != tt.equal { - t.Fatalf("expected ip(%v).equal(ip(%v)) to be %v instead of %v", tt.lhs, tt.rhs, tt.equal, equal) + t.Fatalf("expected ip(%v).Equal(ip(%v)) to be %v instead of %v", tt.lhs, tt.rhs, tt.equal, equal) } if equal { - testutilFatalIf( + testutil.FatalIf( t, - !lhs.contains(rhs), - "ip(%v) and ip(%v) compare equal but !ip(%v).contains(ip(%v))", tt.lhs, tt.rhs, tt.lhs, tt.rhs) - testutilFatalIf( + !lhs.Contains(rhs), + "ip(%v) and ip(%v) compare Equal but !ip(%v).contains(ip(%v))", tt.lhs, tt.rhs, tt.lhs, tt.rhs) + testutil.FatalIf( t, - !rhs.contains(lhs), - "ip(%v) and ip(%v) compare equal but !ip(%v).contains(ip(%v))", tt.rhs, tt.lhs, tt.rhs, tt.lhs) + !rhs.Contains(lhs), + "ip(%v) and ip(%v) compare Equal but !ip(%v).contains(ip(%v))", tt.rhs, tt.lhs, tt.rhs, tt.lhs) } }) } @@ -598,12 +600,12 @@ func TestIP(t *testing.T) { t.Run(fmt.Sprintf("ip(%v).isIPv{4,6}()", tt.val), func(t *testing.T) { t.Parallel() val, err := ParseIPAddr(tt.val) - testutilOK(t, err) - isIPv4 := val.isIPv4() + testutil.OK(t, err) + isIPv4 := val.IsIPv4() if isIPv4 != tt.isIPv4 { t.Fatalf("expected ip(%v).isIPv4() to be %v instead of %v", tt.val, tt.isIPv4, isIPv4) } - isIPv6 := val.isIPv6() + isIPv6 := val.IsIPv6() if isIPv6 != tt.isIPv6 { t.Fatalf("expected ip(%v).isIPv6() to be %v instead of %v", tt.val, tt.isIPv6, isIPv6) } @@ -647,8 +649,8 @@ func TestIP(t *testing.T) { t.Run(fmt.Sprintf("ip(%v).isLoopback()", tt.val), func(t *testing.T) { t.Parallel() val, err := ParseIPAddr(tt.val) - testutilOK(t, err) - isLoopback := val.isLoopback() + testutil.OK(t, err) + isLoopback := val.IsLoopback() if isLoopback != tt.isLoopback { t.Fatalf("expected ip(%v).isLoopback() to be %v instead of %v", tt.val, tt.isLoopback, isLoopback) } @@ -681,8 +683,8 @@ func TestIP(t *testing.T) { t.Run(fmt.Sprintf("ip(%v).isMulticast()", tt.val), func(t *testing.T) { t.Parallel() val, err := ParseIPAddr(tt.val) - testutilOK(t, err) - isMulticast := val.isMulticast() + testutil.OK(t, err) + isMulticast := val.IsMulticast() if isMulticast != tt.isMulticast { t.Fatalf("expected ip(%v).isMulticast() to be %v instead of %v", tt.val, tt.isMulticast, isMulticast) } @@ -714,10 +716,10 @@ func TestIP(t *testing.T) { t.Run(fmt.Sprintf("ip(%v).contains(ip(%v))", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() lhs, err := ParseIPAddr(tt.lhs) - testutilOK(t, err) + testutil.OK(t, err) rhs, err := ParseIPAddr(tt.rhs) - testutilOK(t, err) - contains := lhs.contains(rhs) + testutil.OK(t, err) + contains := lhs.Contains(rhs) if contains != tt.contains { t.Fatalf("expected ip(%v).contains(ip(%v)) to be %v instead of %v", tt.lhs, tt.rhs, tt.contains, contains) } @@ -725,10 +727,10 @@ func TestIP(t *testing.T) { } }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := IPAddr{}.typeName() - testutilEquals(t, tn, "IP") + tn := IPAddr{}.TypeName() + testutil.Equals(t, tn, "IP") }) } @@ -738,140 +740,140 @@ func TestDeepClone(t *testing.T) { t.Parallel() a := Boolean(true) b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a = Boolean(false) - testutilEquals(t, a, Boolean(false)) - testutilEquals(t, b, Value(Boolean(true))) + testutil.Equals(t, a, Boolean(false)) + testutil.Equals(t, b, Value(Boolean(true))) }) t.Run("Long", func(t *testing.T) { t.Parallel() a := Long(42) b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a = Long(43) - testutilEquals(t, a, Long(43)) - testutilEquals(t, b, Value(Long(42))) + testutil.Equals(t, a, Long(43)) + testutil.Equals(t, b, Value(Long(42))) }) t.Run("String", func(t *testing.T) { t.Parallel() a := String("cedar") b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a = String("policy") - testutilEquals(t, a, String("policy")) - testutilEquals(t, b, Value(String("cedar"))) + testutil.Equals(t, a, String("policy")) + testutil.Equals(t, b, Value(String("cedar"))) }) t.Run("EntityUID", func(t *testing.T) { t.Parallel() a := NewEntityUID("Action", "test") b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a.ID = "bananas" - testutilEquals(t, a, NewEntityUID("Action", "bananas")) - testutilEquals(t, b, Value(NewEntityUID("Action", "test"))) + testutil.Equals(t, a, NewEntityUID("Action", "bananas")) + testutil.Equals(t, b, Value(NewEntityUID("Action", "test"))) }) t.Run("Set", func(t *testing.T) { t.Parallel() a := Set{Long(42)} b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a[0] = String("bananas") - testutilEquals(t, a, Set{String("bananas")}) - testutilEquals(t, b, Value(Set{Long(42)})) + testutil.Equals(t, a, Set{String("bananas")}) + testutil.Equals(t, b, Value(Set{Long(42)})) }) t.Run("NilSet", func(t *testing.T) { t.Parallel() var a Set b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) }) t.Run("Record", func(t *testing.T) { t.Parallel() a := Record{"key": Long(42)} b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a["key"] = String("bananas") - testutilEquals(t, a, Record{"key": String("bananas")}) - testutilEquals(t, b, Value(Record{"key": Long(42)})) + testutil.Equals(t, a, Record{"key": String("bananas")}) + testutil.Equals(t, b, Value(Record{"key": Long(42)})) }) t.Run("NilRecord", func(t *testing.T) { t.Parallel() var a Record b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) }) t.Run("Decimal", func(t *testing.T) { t.Parallel() a := Decimal(42) b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a = Decimal(43) - testutilEquals(t, a, Decimal(43)) - testutilEquals(t, b, Value(Decimal(42))) + testutil.Equals(t, a, Decimal(43)) + testutil.Equals(t, b, Value(Decimal(42))) }) t.Run("IPAddr", func(t *testing.T) { t.Parallel() a := mustIPValue("127.0.0.42") b := a.deepClone() - testutilEquals(t, a.Cedar(), b.Cedar()) + testutil.Equals(t, a.Cedar(), b.Cedar()) a = mustIPValue("127.0.0.43") - testutilEquals(t, a.Cedar(), mustIPValue("127.0.0.43").Cedar()) - testutilEquals(t, b.Cedar(), mustIPValue("127.0.0.42").Cedar()) + testutil.Equals(t, a.Cedar(), mustIPValue("127.0.0.43").Cedar()) + testutil.Equals(t, b.Cedar(), mustIPValue("127.0.0.42").Cedar()) }) } func TestPath(t *testing.T) { t.Parallel() - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() - a := path("X") - b := path("X") - c := path("Y") - testutilEquals(t, a.equal(b), true) - testutilEquals(t, b.equal(a), true) - testutilEquals(t, a.equal(c), false) - testutilEquals(t, c.equal(a), false) + a := Path("X") + b := Path("X") + c := Path("Y") + testutil.Equals(t, a.Equal(b), true) + testutil.Equals(t, b.Equal(a), true) + testutil.Equals(t, a.Equal(c), false) + testutil.Equals(t, c.Equal(a), false) }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - a := path("X") - testutilEquals(t, a.typeName(), "(path of type `X`)") + a := Path("X") + testutil.Equals(t, a.TypeName(), "(Path of type `X`)") }) t.Run("String", func(t *testing.T) { t.Parallel() - a := path("X") - testutilEquals(t, a.String(), "X") + a := Path("X") + testutil.Equals(t, a.String(), "X") }) t.Run("Cedar", func(t *testing.T) { t.Parallel() - a := path("X") - testutilEquals(t, a.Cedar(), "X") + a := Path("X") + testutil.Equals(t, a.Cedar(), "X") }) t.Run("ExplicitMarshalJSON", func(t *testing.T) { t.Parallel() - a := path("X") + a := Path("X") v, err := a.ExplicitMarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(v), `"X"`) + testutil.OK(t, err) + testutil.Equals(t, string(v), `"X"`) }) t.Run("deepClone", func(t *testing.T) { t.Parallel() - a := path("X") + a := Path("X") b := a.deepClone() - c, ok := b.(path) - testutilEquals(t, ok, true) - testutilEquals(t, c, a) + c, ok := b.(Path) + testutil.Equals(t, ok, true) + testutil.Equals(t, c, a) }) t.Run("pathFromSlice", func(t *testing.T) { t.Parallel() - a := pathFromSlice([]string{"X", "Y"}) - testutilEquals(t, a, path("X::Y")) + a := PathFromSlice([]string{"X", "Y"}) + testutil.Equals(t, a, Path("X::Y")) }) } diff --git a/x/exp/ast/ast_test.go b/x/exp/ast/ast_test.go index 6c4aaa4f..9e513e62 100644 --- a/x/exp/ast/ast_test.go +++ b/x/exp/ast/ast_test.go @@ -3,8 +3,8 @@ package ast_test import ( "testing" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/ast" - "github.com/cedar-policy/cedar-go/x/exp/types" ) // These tests mostly verify that policy ASTs compile diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 38f2b7e4..7e261cea 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -1,9 +1,9 @@ package ast -type opType uint8 +type nodeType uint8 const ( - nodeTypeAccess opType = iota + nodeTypeAccess nodeType = iota nodeTypeAdd nodeTypeAnd nodeTypeAnnotation @@ -41,7 +41,7 @@ const ( ) type Node struct { - op opType + nodeType nodeType // TODO: Should we just have `value any`? args []Node value any diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index c3b89b21..d8b0c177 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -1,9 +1,9 @@ package ast -import "github.com/cedar-policy/cedar-go/x/exp/types" +import "github.com/cedar-policy/cedar-go/types" -// ____ _ -// / ___|___ _ __ ___ _ __ __ _ _ __(_)___ ___ _ __ +// ____ _ +// / ___|___ _ __ ___ _ __ __ _ _ __(_)___ ___ _ __ // | | / _ \| '_ ` _ \| '_ \ / _` | '__| / __|/ _ \| '_ \ // | |__| (_) | | | | | | |_) | (_| | | | \__ \ (_) | | | | // \____\___/|_| |_| |_| .__/ \__,_|_| |_|___/\___/|_| |_| @@ -152,6 +152,6 @@ func (lhs Node) IsInRange(rhs Node) Node { return newOpNode(nodeTypeIsInRange, lhs, rhs) } -func newOpNode(op opType, args ...Node) Node { - return Node{op: op, args: args} +func newOpNode(op nodeType, args ...Node) Node { + return Node{nodeType: op, args: args} } diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index 9a754e2c..05914236 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -1,6 +1,6 @@ package ast -import "github.com/cedar-policy/cedar-go/x/exp/types" +import "github.com/cedar-policy/cedar-go/types" func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { p.principal = Principal().Equals(Entity(entity)) @@ -16,7 +16,7 @@ func (p *Policy) PrincipalIn(entities ...types.EntityUID) *Policy { return p } -func (p *Policy) PrincipalIs(entityType types.EntityType) *Policy { +func (p *Policy) PrincipalIs(entityType string) *Policy { p.principal = Principal().Is(EntityType(entityType)) return p } @@ -49,7 +49,7 @@ func (p *Policy) ResourceIn(entities ...types.EntityUID) *Policy { return p } -func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { +func (p *Policy) ResourceIs(entityType string) *Policy { p.principal = Resource().Is(EntityType(entityType)) return p } diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index d5d46f26..3d1f0f66 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -3,7 +3,7 @@ package ast import ( "fmt" - "github.com/cedar-policy/cedar-go/x/exp/types" + "github.com/cedar-policy/cedar-go/types" ) func Boolean(b types.Boolean) Node { @@ -76,7 +76,7 @@ func RecordNodes(nodes map[string]Node) Node { return newValueNode(nodeTypeRecord, nodes) } -func EntityType(e types.EntityType) Node { +func EntityType(e string) Node { return newValueNode(nodeTypeEntityType, e) } @@ -88,12 +88,12 @@ func Decimal(d types.Decimal) Node { return newValueNode(nodeTypeEntity, d) } -func IpAddr(i types.IpAddr) Node { +func IPAddr(i types.IPAddr) Node { return newValueNode(nodeTypeIpAddr, i) } -func newValueNode(op opType, v any) Node { - return Node{op: op, value: v} +func newValueNode(nodeType nodeType, v any) Node { + return Node{nodeType: nodeType, value: v} } func valueToNode(v types.Value) Node { @@ -112,8 +112,8 @@ func valueToNode(v types.Value) Node { return Entity(x) case types.Decimal: return Decimal(x) - case types.IpAddr: - return IpAddr(x) + case types.IPAddr: + return IPAddr(x) default: panic(fmt.Sprintf("unexpected value type: %T(%v)", v, v)) } diff --git a/x/exp/ast/variable.go b/x/exp/ast/variable.go index 5ef83687..8a7cb662 100644 --- a/x/exp/ast/variable.go +++ b/x/exp/ast/variable.go @@ -5,7 +5,7 @@ func Principal() Node { } func Action() Node { - return newPrincipalNode() + return newActionNode() } func Resource() Node { diff --git a/x/exp/types/types.go b/x/exp/types/types.go deleted file mode 100644 index bbf6c3aa..00000000 --- a/x/exp/types/types.go +++ /dev/null @@ -1,44 +0,0 @@ -package types - -import "net" - -type Value interface { - isValue() -} - -type Boolean bool - -func (Boolean) isValue() {} - -type String string - -func (String) isValue() {} - -type Long int64 - -func (Long) isValue() {} - -type Set []Value - -func (Set) isValue() {} - -type Record map[string]Value - -func (Record) isValue() {} - -type EntityType string - -type EntityUID struct { - Type string - ID string -} - -func (EntityUID) isValue() {} - -type Decimal []float64 - -func (Decimal) isValue() {} - -type IpAddr net.IPAddr - -func (IpAddr) isValue() {}