diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java index 840af016..89a87f62 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java @@ -430,10 +430,18 @@ private GoWriter.Writable compareVariables(String ident, Variable left, Variable } } - // also, if they're both pointers, and it's equality, there's an additional true case where both are nil - var elseCmpBothNull = !isOrderComparator(cmp) && isLPtr && isRPtr - ? goTemplate("else { $L = $L == nil && $L == nil }", ident, left.ident, right.ident) - : emptyGoTemplate(); + // also, if they're both pointers, and it's (in)equality, there's an additional true case where both are nil, + // or both are different + var elseCheckPtrs = emptyGoTemplate(); + if (isLPtr && isRPtr) { + if (cmp == ComparatorType.EQUAL) { + elseCheckPtrs = goTemplate("else { $L = $L == nil && $L == nil }", + ident, left.ident, right.ident); + } else if (cmp == ComparatorType.NOT_EQUAL) { + elseCheckPtrs = goTemplate("else { $1L = ($2L == nil && $3L != nil) || ($2L != nil && $3L == nil) }", + ident, left.ident, right.ident); + } + } return goTemplate(""" var $ident:L bool @@ -441,7 +449,7 @@ private GoWriter.Writable compareVariables(String ident, Variable left, Variable $nilCoerceRight:W if $lif:L $amp:L $rif:L { $ident:L = $cast:L($lhs:L) $cmp:L $cast:L($rhs:L) - }$elseCmpBothNull:W""", + }$elseCheckPtrs:W""", Map.of( "ident", ident, "lif", isLPtr ? left.ident + " != nil" : "", @@ -455,7 +463,7 @@ private GoWriter.Writable compareVariables(String ident, Variable left, Variable "nilCoerceRight", nilCoerceRight ), Map.of( - "elseCmpBothNull", elseCmpBothNull + "elseCheckPtrs", elseCheckPtrs )); } diff --git a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java index a4b96e76..0a6b01a3 100644 --- a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java +++ b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java @@ -608,4 +608,50 @@ public void testOrderComparatorNumberCoercesBothNullable() { } """)); } + + @Test + public void testEqualBothNullable() { + var expr = "nullableIntegerA == nullableIntegerB"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean")); + assertThat(actual.ident(), Matchers.equalTo("v3")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.NullableIntegerA + v2 := input.NullableIntegerB + var v3 bool + + if v1 != nil && v2 != nil { + v3 = int64(*v1) == int64(*v2) + }else { v3 = v1 == nil && v2 == nil } + """)); + } + + @Test + public void testNotEqualBothNullable() { + var expr = "nullableIntegerA != nullableIntegerB"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean")); + assertThat(actual.ident(), Matchers.equalTo("v3")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.NullableIntegerA + v2 := input.NullableIntegerB + var v3 bool + + if v1 != nil && v2 != nil { + v3 = int64(*v1) != int64(*v2) + }else { v3 = (v1 == nil && v2 != nil) || (v1 != nil && v2 == nil) } + """)); + } }