diff --git a/type.go b/type.go index f4d03f99a..2e6936cd9 100644 --- a/type.go +++ b/type.go @@ -88,6 +88,15 @@ func (n binaryNode) Type(table typesTable) (Type, error) { } return nil, fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, n, ltype, rtype) + case "in", "not in": + if (isStringType(ltype) || isInterfaceType(ltype)) && (isStructType(rtype) || isInterfaceType(rtype)) { + return boolType, nil + } + if isArrayType(rtype) || isMapType(rtype) || isInterfaceType(rtype) { + return boolType, nil + } + return nil, fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, n, ltype, rtype) + case "<", ">", ">=", "<=": if (isNumberType(ltype) || isInterfaceType(ltype)) && (isNumberType(rtype) || isInterfaceType(rtype)) { return boolType, nil @@ -325,6 +334,39 @@ func isStringType(t Type) bool { return false } +func isArrayType(t Type) bool { + t = dereference(t) + if t != nil { + switch t.Kind() { + case reflect.Slice, reflect.Array: + return true + } + } + return false +} + +func isMapType(t Type) bool { + t = dereference(t) + if t != nil { + switch t.Kind() { + case reflect.Map: + return true + } + } + return false +} + +func isStructType(t Type) bool { + t = dereference(t) + if t != nil { + switch t.Kind() { + case reflect.Struct: + return true + } + } + return false +} + func fieldType(ntype Type, name string) (Type, bool) { ntype = dereference(ntype) if ntype != nil { diff --git a/type_test.go b/type_test.go index e902bdfcc..46eb0d853 100644 --- a/type_test.go +++ b/type_test.go @@ -52,6 +52,10 @@ var typeTests = []typeTest{ "nil == nil", "nil == IntPtr", "Foo2p.Bar.Baz", + "Str in Foo", + "Str in Arr", + "nil in Arr", + "Str not in Foo2p", "Int | Num", "Int ^ Num", "Int & Num", @@ -268,6 +272,14 @@ var typeErrorTests = []typeErrorTest{ "NilFn() and OkFn()", "invalid operation: NilFn() and OkFn() (mismatched types and bool)", }, + { + "'str' in Str", + `invalid operation: "str" in Str (mismatched types string and string)`, + }, + { + "1 in Foo", + "invalid operation: 1 in Foo (mismatched types float64 and *expr_test.foo)", + }, } type abc interface {