Skip to content

Commit

Permalink
fix: address various failing protocol tests (#1223)
Browse files Browse the repository at this point in the history
  • Loading branch information
lauzadis authored Jan 23, 2025
1 parent 0f8db44 commit 8b33693
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 63 deletions.
12 changes: 6 additions & 6 deletions codegen/protocol-tests/model/error-correction-tests.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ operation SayHelloXml { output: TestOutput, errors: [Error] }

structure TestOutputDocument with [TestStruct] {
innerField: Nested,
// FIXME: This trait fails smithy validator
// @required

// Note: This shape _should_ be @required, but causes Smithy httpResponseTests validation to fail.
// We expect `document` to be deserialized as `null` and enforce @required using a runtime check, but Smithy validator doesn't recognize / allow this.
document: Document
}
structure TestOutput with [TestStruct] { innerField: Nested }
Expand All @@ -65,8 +66,8 @@ structure TestStruct {
@required
nestedListValue: NestedList

// FIXME: This trait fails smithy validator
// @required
// Note: This shape _should_ be @required, but causes Smithy httpResponseTests validation to fail.
// We expect `nested` to be deserialized as `null` and enforce @required using a runtime check, but Smithy validator doesn't recognize / allow this.
nested: Nested

@required
Expand Down Expand Up @@ -97,8 +98,7 @@ union MyUnion {
}

structure Nested {
// FIXME: This trait fails smithy validator
// @required
@required
a: String
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,7 @@ abstract class AwsHttpBindingProtocolGenerator : HttpBindingProtocolGenerator()
// val targetedTest = TestMemberDelta(setOf("RestJsonComplexErrorWithNoMessage"), TestContainmentMode.RUN_TESTS)

val ignoredTests = TestMemberDelta(
setOf(
"AwsJson10ClientErrorCorrectsWithDefaultValuesWhenServerFailsToSerializeRequiredValues",
"RestJsonNullAndEmptyHeaders",
"NullAndEmptyHeaders",
"RpcV2CborClientPopulatesDefaultsValuesWhenMissingInResponse",
"RpcV2CborClientPopulatesDefaultValuesInInput",
),
setOf(),
)

val requestTestBuilder = HttpProtocolUnitTestRequestGenerator.Builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli
} else {
// only use @default if type is `T`
shape.getTrait<DefaultTrait>()?.let {
defaultValue(it.getDefaultValue(targetShape))
setDefaultValue(it, targetShape)
}
}
}
Expand All @@ -219,9 +219,10 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli
}
}

private fun DefaultTrait.getDefaultValue(targetShape: Shape): String? {
val node = toNode()
return when {
private fun Symbol.Builder.setDefaultValue(defaultTrait: DefaultTrait, targetShape: Shape) {
val node = defaultTrait.toNode()

val defaultValue = when {
node.toString() == "null" -> null

// Check if target is an enum before treating the default like a regular number/string
Expand All @@ -235,13 +236,20 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli
"${enumSymbol.fullName}.fromValue($arg)"
}

targetShape.isBlobShape && targetShape.isStreaming ->
node
.toString()
.takeUnless { it.isEmpty() }
?.let { "ByteStream.fromString(${it.dq()})" }
targetShape.isBlobShape -> {
addReferences(RuntimeTypes.Core.Text.Encoding.decodeBase64)

targetShape.isBlobShape -> "${node.toString().dq()}.encodeToByteArray()"
if (targetShape.isStreaming) {
node.toString()
.takeUnless { it.isEmpty() }
?.let {
addReferences(RuntimeTypes.Core.Content.ByteStream)
"ByteStream.fromString(${it.dq()}.decodeBase64())"
}
} else {
"${node.toString().dq()}.decodeBase64().encodeToByteArray()"
}
}

targetShape.isDocumentShape -> getDefaultValueForDocument(node)
targetShape.isTimestampShape -> getDefaultValueForTimestamp(node.asNumberNode().get())
Expand All @@ -252,6 +260,8 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli
node.isStringNode -> node.toString().dq()
else -> node.toString()
}

defaultValue(defaultValue)
}

private fun getDefaultValueForTimestamp(node: NumberNode): String {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ class StructureGenerator(
} else {
memberSymbol
}

write("public var #L: #E", memberName, builderMemberSymbol)
}
write("")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,8 @@ class HttpStringValuesMapSerializer(
val paramName = binding.locationName
// addAll collection parameter 2
val param2 = if (mapFnContents.isEmpty()) "input.$memberName" else "input.$memberName.map { $mapFnContents }"
val nullCheck = if (memberSymbol.isNullable) "?" else ""
writer.write(
"if (input.#L$nullCheck.isNotEmpty() == true) #L(#S, #L)",
"if (input.#L != null) #L(#S, #L)",
memberName,
binding.location.addAllFnName,
paramName,
Expand All @@ -174,8 +173,7 @@ class HttpStringValuesMapSerializer(
val paramName = binding.locationName
val memberSymbol = symbolProvider.toSymbol(binding.member)

// NOTE: query parameters are allowed to be empty, whereas headers should omit empty string
// values from serde
// NOTE: query parameters are allowed to be empty
if ((location == HttpBinding.Location.QUERY || location == HttpBinding.Location.HEADER) && binding.member.hasTrait<IdempotencyTokenTrait>()) {
// Call the idempotency token function if no supplied value.
writer.addImport(RuntimeTypes.SmithyClient.IdempotencyTokenProviderExt)
Expand All @@ -185,18 +183,7 @@ class HttpStringValuesMapSerializer(
paramName,
)
} else {
val nullCheck =
if (location == HttpBinding.Location.QUERY ||
memberTarget.hasTrait<
@Suppress("DEPRECATION")
software.amazon.smithy.model.traits.EnumTrait,
>()
) {
if (memberSymbol.isNullable) "input.$memberName != null" else ""
} else {
val nullCheck = if (memberSymbol.isNullable) "?" else ""
"input.$memberName$nullCheck.isNotEmpty() == true"
}
val nullCheck = if (memberSymbol.isNullable) "input.$memberName != null" else ""

val cond = defaultCheck(binding.member) ?: nullCheck

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class SymbolProviderTest {
"double,2.71828,2.71828",
"byte,10,10.toByte()",
"string,\"hello\",\"hello\"",
"blob,\"abcdefg\",\"abcdefg\".encodeToByteArray()",
"blob,\"abcdefg\",\"abcdefg\".decodeBase64().encodeToByteArray()",
"boolean,true,true",
"bigInteger,5,5",
"bigDecimal,9.0123456789,9.0123456789",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ internal class SmokeTestOperationSerializer: HttpSerializer.NonStreaming<SmokeTe
}
builder.headers {
if (input.header1?.isNotEmpty() == true) append("X-Header1", input.header1)
if (input.header2?.isNotEmpty() == true) append("X-Header2", input.header2)
if (input.header1 != null) append("X-Header1", input.header1)
if (input.header2 != null) append("X-Header2", input.header2)
}
val payload = serializeSmokeTestOperationBody(context, input)
Expand Down Expand Up @@ -264,7 +264,7 @@ internal class TimestampInputOperationSerializer: HttpSerializer.NonStreaming<Ti
}
parameters.decodedParameters(PercentEncoding.SmithyLabel) {
if (input.queryTimestamp != null) add("qtime", input.queryTimestamp.format(TimestampFormat.ISO_8601))
if (input.queryTimestampList?.isNotEmpty() == true) addAll("qtimeList", input.queryTimestampList.map { it.format(TimestampFormat.ISO_8601) })
if (input.queryTimestampList != null) addAll("qtimeList", input.queryTimestampList.map { it.format(TimestampFormat.ISO_8601) })
}
}
Expand Down Expand Up @@ -304,7 +304,7 @@ internal class BlobInputOperationSerializer: HttpSerializer.NonStreaming<BlobInp
}
builder.headers {
if (input.headerMediaType?.isNotEmpty() == true) append("X-Blob", input.headerMediaType.encodeBase64())
if (input.headerMediaType != null) append("X-Blob", input.headerMediaType.encodeBase64())
}
val payload = serializeBlobInputOperationBody(context, input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ class HttpStringValuesMapSerializerTest {
contents.assertBalancedBracesAndParens()

val expectedContents = """
if (input.header1?.isNotEmpty() == true) append("X-Header1", input.header1)
if (input.header2?.isNotEmpty() == true) append("X-Header2", input.header2)
if (input.header1 != null) append("X-Header1", input.header1)
if (input.header2 != null) append("X-Header2", input.header2)
""".trimIndent()
contents.shouldContainOnlyOnceWithDiff(expectedContents)
}
Expand All @@ -157,7 +157,7 @@ class HttpStringValuesMapSerializerTest {
contents.assertBalancedBracesAndParens()

val expectedContents = """
if (input.headerMediaType?.isNotEmpty() == true) append("X-Blob", input.headerMediaType.encodeBase64())
if (input.headerMediaType != null) append("X-Blob", input.headerMediaType.encodeBase64())
""".trimIndent()
contents.shouldContainOnlyOnceWithDiff(expectedContents)
}
Expand All @@ -168,10 +168,10 @@ class HttpStringValuesMapSerializerTest {
contents.assertBalancedBracesAndParens()

val expectedContents = """
if (input.enumList?.isNotEmpty() == true) appendAll("x-enumList", input.enumList.map { quoteHeaderValue(it.value) })
if (input.intList?.isNotEmpty() == true) appendAll("x-intList", input.intList.map { it.toString() })
if (input.strList?.isNotEmpty() == true) appendAll("x-strList", input.strList.map { quoteHeaderValue(it) })
if (input.tsList?.isNotEmpty() == true) appendAll("x-tsList", input.tsList.map { it.format(TimestampFormat.RFC_5322) })
if (input.enumList != null) appendAll("x-enumList", input.enumList.map { quoteHeaderValue(it.value) })
if (input.intList != null) appendAll("x-intList", input.intList.map { it.toString() })
if (input.strList != null) appendAll("x-strList", input.strList.map { quoteHeaderValue(it) })
if (input.tsList != null) appendAll("x-tsList", input.tsList.map { it.format(TimestampFormat.RFC_5322) })
""".trimIndent()
contents.shouldContainOnlyOnceWithDiff(expectedContents)
}
Expand All @@ -190,7 +190,7 @@ class HttpStringValuesMapSerializerTest {
val queryContents = getTestContents(defaultModel, "com.test#TimestampInput", HttpBinding.Location.QUERY)
val expectedQueryContents = """
if (input.queryTimestamp != null) add("qtime", input.queryTimestamp.format(TimestampFormat.ISO_8601))
if (input.queryTimestampList?.isNotEmpty() == true) addAll("qtimeList", input.queryTimestampList.map { it.format(TimestampFormat.ISO_8601) })
if (input.queryTimestampList != null) addAll("qtimeList", input.queryTimestampList.map { it.format(TimestampFormat.ISO_8601) })
""".trimIndent()
queryContents.shouldContainOnlyOnceWithDiff(expectedQueryContents)
}
Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ micrometer-version = "1.14.2"
binary-compatibility-validator-version = "0.16.3"

# codegen
smithy-version = "1.53.0"
smithy-version = "1.54.0"
smithy-gradle-version = "0.9.0"

# testing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,17 @@ public fun String.decodeBase64Bytes(): ByteArray = encodeToByteArray().decodeBas
* Decode [ByteArray] from base64 format
*/
public fun ByteArray.decodeBase64(): ByteArray {
val encoded = this
// Calculate the padding needed to make the length a multiple of 4
val remainder = size % 4
val encoded: ByteArray = if (remainder == 0) {
this
} else {
this + ByteArray(4 - remainder) { BASE64_PAD.code.toByte() }
}

val decodedLen = base64DecodedLen(encoded)
val decoded = ByteArray(decodedLen)
val blockCnt = size / 4
val blockCnt = encoded.size / 4
var readIdx = 0
var writeIdx = 0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,6 @@ class Base64Test {
ex.message!!.shouldContain("decode base64: invalid input byte: 45")
}

@Test
fun decodeNonMultipleOf4() {
val ex = assertFails {
"Zm9vY=".decodeBase64()
}
ex.message!!.shouldContain("invalid base64 string of length 6; not a multiple of 4")
}

@Test
fun decodeInvalidPadding() {
val ex = assertFails {
Expand Down Expand Up @@ -116,4 +108,24 @@ class Base64Test {
assertEquals(encoded, decoded.encodeBase64())
assertEquals(decoded, encoded.decodeBase64())
}

@Test
fun testUnpaddedInputs() {
// from https://github.com/smithy-lang/smithy/pull/2502
val input = "v2hkZWZhdWx0c79tZGVmYXVsdFN0cmluZ2JoaW5kZWZhdWx0Qm9vbGVhbvVrZGVmYXVsdExpc3Sf/3BkZWZhdWx0VGltZXN0YW1wwQBrZGVmYXVsdEJsb2JDYWJja2RlZmF1bHRCeXRlAWxkZWZhdWx0U2hvcnQBbmRlZmF1bHRJbnRlZ2VyCmtkZWZhdWx0TG9uZxhkbGRlZmF1bHRGbG9hdPo/gAAAbWRlZmF1bHREb3VibGX6P4AAAGpkZWZhdWx0TWFwv/9rZGVmYXVsdEVudW1jRk9PbmRlZmF1bHRJbnRFbnVtAWtlbXB0eVN0cmluZ2BsZmFsc2VCb29sZWFu9GllbXB0eUJsb2JAaHplcm9CeXRlAGl6ZXJvU2hvcnQAa3plcm9JbnRlZ2VyAGh6ZXJvTG9uZwBpemVyb0Zsb2F0+gAAAABqemVyb0RvdWJsZfoAAAAA//8"
input.decodeBase64()

val inputs = mapOf<String, String>(
"YQ" to "a",
"Yg" to "b",
"YWI" to "ab",
"YWJj" to "abc",
"SGVsbG8gd29ybGQ" to "Hello world",
)

inputs.forEach { (encoded, expected) ->
val actual = encoded.decodeBase64()
assertEquals(expected, actual)
}
}
}

0 comments on commit 8b33693

Please sign in to comment.