Skip to content

Commit

Permalink
Enable comparison feature like python
Browse files Browse the repository at this point in the history
  • Loading branch information
ckganesan committed May 31, 2024
1 parent 7e6e6f5 commit 6bbe56c
Show file tree
Hide file tree
Showing 13 changed files with 665 additions and 494 deletions.
8 changes: 8 additions & 0 deletions ast/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,11 @@ type PairNode struct {
Key Node // Key of the pair.
Value Node // Value of the pair.
}

// CompareNode represents comparison
type CompareNode struct {
base
Left Node // Left represents the left-hand side of the comparison operation
Operators []string // Operators is a list of comparison operator tokens used in the comparison.
Comparators []Node // Comparators representing the right-hand sides of the comparison operation
}
31 changes: 31 additions & 0 deletions ast/print.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,34 @@ func (n *PairNode) String() string {
}
return fmt.Sprintf("(%s): %s", n.Key.String(), n.Value.String())
}

func (n *CompareNode) string(node Node) string {
switch v := node.(type) {
case *BinaryNode, *CompareNode:
return fmt.Sprintf("(%s)", v)
default:
return v.String()
}
}

func (n *CompareNode) String() string {
var builder strings.Builder
builder.WriteString(n.string(n.Left))
opIdx := 0
for i := 0; i < len(n.Comparators); i++ {
if op := n.Operators[opIdx]; op != "&&" {
builder.WriteByte(' ')
builder.WriteString(op)
if op == "not" {
opIdx++
builder.WriteByte(' ')
builder.WriteString(n.Operators[opIdx])
}
builder.WriteByte(' ')
builder.WriteString(n.string(n.Comparators[i]))
}
opIdx++
}

return builder.String()
}
2 changes: 1 addition & 1 deletion ast/print_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestPrint(t *testing.T) {
{`a == b`, `a == b`},
{`a matches b`, `a matches b`},
{`a in b`, `a in b`},
{`a not in b`, `not (a in b)`},
{`a not in b`, `a not in b`},
{`a and b`, `a and b`},
{`a or b`, `a or b`},
{`a or b and c`, `a or (b and c)`},
Expand Down
5 changes: 5 additions & 0 deletions ast/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ func Walk(node *Node, v Visitor) {
case *PairNode:
Walk(&n.Key, v)
Walk(&n.Value, v)
case *CompareNode:
Walk(&n.Left, v)
for i := range n.Comparators {
Walk(&n.Comparators[i], v)
}
default:
panic(fmt.Sprintf("undefined node type (%T)", node))
}
Expand Down
170 changes: 95 additions & 75 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ func (v *checker) visit(node ast.Node) (reflect.Type, info) {
t, i = v.MapNode(n)
case *ast.PairNode:
t, i = v.PairNode(n)
case *ast.CompareNode:
t, i = v.CompareNode(n)
default:
panic(fmt.Sprintf("undefined node type (%T)", node))
}
Expand Down Expand Up @@ -272,17 +274,12 @@ func (v *checker) UnaryNode(node *ast.UnaryNode) (reflect.Type, info) {

func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
l, _ := v.visit(node.Left)
r, ri := v.visit(node.Right)
r, _ := v.visit(node.Right)

l = deref.Type(l)
r = deref.Type(r)

switch node.Operator {
case "==", "!=":
if isComparable(l, r) {
return boolType, info{}
}

case "or", "||", "and", "&&":
if isBool(l) && isBool(r) {
return boolType, info{}
Expand All @@ -291,20 +288,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
return boolType, info{}
}

case "<", ">", ">=", "<=":
if isNumber(l) && isNumber(r) {
return boolType, info{}
}
if isString(l) && isString(r) {
return boolType, info{}
}
if isTime(l) && isTime(r) {
return boolType, info{}
}
if or(l, r, isNumber, isString, isTime) {
return boolType, info{}
}

case "-":
if isNumber(l) && isNumber(r) {
return combined(l, r), info{}
Expand Down Expand Up @@ -368,60 +351,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
return anyType, info{}
}

case "in":
if (isString(l) || isAny(l)) && isStruct(r) {
return boolType, info{}
}
if isMap(r) {
if l == nil { // It is possible to compare with nil.
return boolType, info{}
}
if !isAny(l) && !l.AssignableTo(r.Key()) {
return v.error(node, "cannot use %v as type %v in map key", l, r.Key())
}
return boolType, info{}
}
if isArray(r) {
if l == nil { // It is possible to compare with nil.
return boolType, info{}
}
if !isComparable(l, r.Elem()) {
return v.error(node, "cannot use %v as type %v in array", l, r.Elem())
}
if !isComparable(l, ri.elem) {
return v.error(node, "cannot use %v as type %v in array", l, ri.elem)
}
return boolType, info{}
}
if isAny(l) && anyOf(r, isString, isArray, isMap) {
return boolType, info{}
}
if isAny(r) {
return boolType, info{}
}

case "matches":
if s, ok := node.Right.(*ast.StringNode); ok {
_, err := regexp.Compile(s.Value)
if err != nil {
return v.error(node, err.Error())
}
}
if isString(l) && isString(r) {
return boolType, info{}
}
if or(l, r, isString) {
return boolType, info{}
}

case "contains", "startsWith", "endsWith":
if isString(l) && isString(r) {
return boolType, info{}
}
if or(l, r, isString) {
return boolType, info{}
}

case "..":
ret := reflect.SliceOf(integerType)
if isInteger(l) && isInteger(r) {
Expand All @@ -448,7 +377,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {

default:
return v.error(node, "unknown operator (%v)", node.Operator)

}

return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r)
Expand Down Expand Up @@ -1207,3 +1135,95 @@ func (v *checker) PairNode(node *ast.PairNode) (reflect.Type, info) {
v.visit(node.Value)
return nilType, info{}
}

func (v *checker) CompareNode(node *ast.CompareNode) (reflect.Type, info) {
nodeLeft := node.Left
opIdx := 0
operatorOverride := false
for i, comparator := range node.Comparators {
op := node.Operators[opIdx]
if negate := op == "not"; negate {
opIdx++
op = node.Operators[opIdx]
}
if op == "&&" {
if !operatorOverride {
operatorOverride = true
}
} else if err := v.compareNode(op, nodeLeft, comparator, i); err != nil {
return v.error(comparator, err.Error())
}
opIdx++
nodeLeft = comparator
}
if operatorOverride {
return anyType, info{}
}
return boolType, info{}
}

func (v *checker) compareNode(op string, nodeLeft, nodeRight ast.Node, index int) error {
l, _ := v.visit(nodeLeft)
r, ri := v.visit(nodeRight)
l = deref.Type(l)
r = deref.Type(r)
switch op {
case "==", "!=":
if (isBool(r) && index > 0) || isComparable(l, r) {
return nil
}
case "<", ">", ">=", "<=":
if isNumber(l) && isNumber(r) ||
isString(l) && isString(r) ||
isTime(l) && isTime(r) ||
or(l, r, isNumber, isString, isTime) {
return nil
}
case "in":
if (isString(l) || isAny(l)) && isStruct(r) {
return nil
}
if isMap(r) {
if l == nil { // It is possible to compare with nil.
return nil
}
if !isAny(l) && !l.AssignableTo(r.Key()) {
return fmt.Errorf("cannot use %v as type %v in map key", l, r.Key())
}
return nil
}
if isArray(r) {
if l == nil { // It is possible to compare with nil.
return nil
}
if !isComparable(l, r.Elem()) {
return fmt.Errorf("cannot use %v as type %v in array", l, r.Elem())
}
if !isComparable(l, ri.elem) {
return fmt.Errorf("cannot use %v as type %v in array", l, ri.elem)
}
return nil
}
if (isAny(l) && anyOf(r, isString, isArray, isMap)) || isAny(r) {
return nil
}

case "matches":
if s, ok := nodeRight.(*ast.StringNode); ok {
if _, err := regexp.Compile(s.Value); err != nil {
return err
}
}
if (isString(l) && isString(r)) || or(l, r, isString) {
return nil
}
case "contains", "startsWith", "endsWith":
if isString(l) && isString(r) ||
or(l, r, isString) {
return nil
}
default:
return fmt.Errorf("unknown operator (%v)", op)
}
return fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, op, l, r)
}
Loading

0 comments on commit 6bbe56c

Please sign in to comment.