Skip to content

Commit

Permalink
coerce nil numbers to 0 in jmespath codegen (#565)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws authored Jan 23, 2025
1 parent f2ae388 commit d708d1d
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package software.amazon.smithy.go.codegen;

import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate;
import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
import static software.amazon.smithy.go.codegen.SymbolUtils.isNilable;
import static software.amazon.smithy.go.codegen.SymbolUtils.isPointable;
Expand Down Expand Up @@ -407,11 +408,48 @@ private GoWriter.Writable compareVariables(String ident, Variable left, Variable
return goTemplate("$1L := $5L($2L) $4L $5L($3L)", ident, left.ident, right.ident, cmp, cast);
}

// undocumented jmespath behavior: null in numeric _ordering_ comparisons coerces to 0
// this means the subsequent nil checks for numerics are moot, but it's either this or branch the codegen even
// further for questionable benefit
var nilCoerceLeft = emptyGoTemplate();
var nilCoerceRight = emptyGoTemplate();
if (isOrderComparator(cmp)) {
if (isLPtr && left.shape instanceof NumberShape) {
nilCoerceLeft = goTemplate("""
if ($1L == nil) {
$1L = new($2T)
*$1L = 0
}""", left.ident, left.type);
}
if (isRPtr && right.shape instanceof NumberShape) {
nilCoerceRight = goTemplate("""
if ($1L == nil) {
$1L = new($2T)
*$1L = 0
}""", right.ident, right.type);
}
}

// also, if they're both pointers, and it's (in)equality, there's an additional true case where both are nil,
// or both are different
var elseCheckPtrs = emptyGoTemplate();
if (isLPtr && isRPtr) {
if (cmp == ComparatorType.EQUAL) {
elseCheckPtrs = goTemplate("else { $L = $L == nil && $L == nil }",
ident, left.ident, right.ident);
} else if (cmp == ComparatorType.NOT_EQUAL) {
elseCheckPtrs = goTemplate("else { $1L = ($2L == nil && $3L != nil) || ($2L != nil && $3L == nil) }",
ident, left.ident, right.ident);
}
}

return goTemplate("""
var $ident:L bool
$nilCoerceLeft:W
$nilCoerceRight:W
if $lif:L $amp:L $rif:L {
$ident:L = $cast:L($lhs:L) $cmp:L $cast:L($rhs:L)
}""",
}$elseCheckPtrs:W""",
Map.of(
"ident", ident,
"lif", isLPtr ? left.ident + " != nil" : "",
Expand All @@ -420,10 +458,20 @@ private GoWriter.Writable compareVariables(String ident, Variable left, Variable
"cmp", cmp,
"lhs", isLPtr ? "*" + left.ident : left.ident,
"rhs", isRPtr ? "*" + right.ident : right.ident,
"cast", cast
"cast", cast,
"nilCoerceLeft", nilCoerceLeft,
"nilCoerceRight", nilCoerceRight
),
Map.of(
"elseCheckPtrs", elseCheckPtrs
));
}

private static boolean isOrderComparator(ComparatorType cmp) {
return cmp == ComparatorType.GREATER_THAN || cmp == ComparatorType.LESS_THAN
|| cmp == ComparatorType.GREATER_THAN_EQUAL || cmp == ComparatorType.LESS_THAN_EQUAL;
}

/**
* Represents a variable (input, intermediate, or final output) of a JMESPath traversal.
* @param shape The underlying shape referenced by this variable. For certain jmespath expressions (e.g.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ public class GoJmespathExpressionGeneratorTest {
objectList: ObjectList
objectMap: ObjectMap
nested: NestedStruct
nullableIntegerA: Integer
nullableIntegerB: Integer
}
structure Object {
Expand Down Expand Up @@ -318,6 +320,7 @@ public void testComparatorStringLHSNil() {
}
v4 := "foo"
var v5 bool
if v2 != nil {
v5 = string(*v2) == string(v4)
}
Expand Down Expand Up @@ -345,6 +348,7 @@ public void testComparatorStringRHSNil() {
v3 = v4
}
var v5 bool
if v3 != nil {
v5 = string(v1) == string(*v3)
}
Expand Down Expand Up @@ -372,9 +376,10 @@ public void testComparatorStringBothNil() {
}
v4 := input.SimpleShape
var v5 bool
if v2 != nil && v4 != nil {
v5 = string(*v2) == string(*v4)
}
}else { v5 = v2 == nil && v4 == nil }
"""));
}

Expand Down Expand Up @@ -546,4 +551,107 @@ public void testMultiSelectFlatten() {
}
"""));
}

@Test
public void testOrderComparatorNumberCoercesLeftNullable() {
var expr = "nullableIntegerA > `9`";

var writer = testWriter();
var generator = new GoJmespathExpressionGenerator(testContext(), writer);
var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable(
TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")),
"input"
));
assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean"));
assertThat(actual.ident(), Matchers.equalTo("v3"));
assertThat(writer.toString(), Matchers.containsString("""
v1 := input.NullableIntegerA
v2 := 9
var v3 bool
if (v1 == nil) {
v1 = new(int32)
*v1 = 0
}
if v1 != nil {
v3 = int64(*v1) > int64(v2)
}
"""));
}

@Test
public void testOrderComparatorNumberCoercesBothNullable() {
var expr = "nullableIntegerA > nullableIntegerB";

var writer = testWriter();
var generator = new GoJmespathExpressionGenerator(testContext(), writer);
var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable(
TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")),
"input"
));
assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean"));
assertThat(actual.ident(), Matchers.equalTo("v3"));
assertThat(writer.toString(), Matchers.containsString("""
v1 := input.NullableIntegerA
v2 := input.NullableIntegerB
var v3 bool
if (v1 == nil) {
v1 = new(int32)
*v1 = 0
}
if (v2 == nil) {
v2 = new(int32)
*v2 = 0
}
if v1 != nil && v2 != nil {
v3 = int64(*v1) > int64(*v2)
}
"""));
}

@Test
public void testEqualBothNullable() {
var expr = "nullableIntegerA == nullableIntegerB";

var writer = testWriter();
var generator = new GoJmespathExpressionGenerator(testContext(), writer);
var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable(
TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")),
"input"
));
assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean"));
assertThat(actual.ident(), Matchers.equalTo("v3"));
assertThat(writer.toString(), Matchers.containsString("""
v1 := input.NullableIntegerA
v2 := input.NullableIntegerB
var v3 bool
if v1 != nil && v2 != nil {
v3 = int64(*v1) == int64(*v2)
}else { v3 = v1 == nil && v2 == nil }
"""));
}

@Test
public void testNotEqualBothNullable() {
var expr = "nullableIntegerA != nullableIntegerB";

var writer = testWriter();
var generator = new GoJmespathExpressionGenerator(testContext(), writer);
var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable(
TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")),
"input"
));
assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean"));
assertThat(actual.ident(), Matchers.equalTo("v3"));
assertThat(writer.toString(), Matchers.containsString("""
v1 := input.NullableIntegerA
v2 := input.NullableIntegerB
var v3 bool
if v1 != nil && v2 != nil {
v3 = int64(*v1) != int64(*v2)
}else { v3 = (v1 == nil && v2 != nil) || (v1 != nil && v2 == nil) }
"""));
}
}

0 comments on commit d708d1d

Please sign in to comment.