Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws_base64_compute_encoded_len() is now exact, doesn't add 1 extra for null-terminator #1188

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
4 changes: 2 additions & 2 deletions include/aws/common/encoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ int aws_hex_compute_encoded_len(size_t to_encode_len, size_t *encoded_length);

/*
* Base 16 (hex) encodes the contents of to_encode and stores the result in
* output. 0 terminates the result. Assumes the buffer is empty and does not resize on
* output. Assumes the buffer is empty and does not resize on
* insufficient capacity.
*/
AWS_COMMON_API
int aws_hex_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, struct aws_byte_buf *AWS_RESTRICT output);

/*
* Base 16 (hex) encodes the contents of to_encode and appends the result in
* output. Does not 0-terminate. Grows the destination buffer dynamically if necessary.
* output. Grows the destination buffer dynamically if necessary.
*/
AWS_COMMON_API
int aws_hex_encode_append_dynamic(
Expand Down
53 changes: 27 additions & 26 deletions source/encoding.c
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ static const uint8_t BASE64_DECODING_TABLE[256] = {
int aws_hex_compute_encoded_len(size_t to_encode_len, size_t *encoded_length) {
AWS_ASSERT(encoded_length);

size_t temp = (to_encode_len << 1) + 1;
/* For every byte of input, there will be 2 hex chars of encoded output */

size_t temp = to_encode_len << 1;

if (AWS_UNLIKELY(temp < to_encode_len)) {
graebm marked this conversation as resolved.
Show resolved Hide resolved
return aws_raise_error(AWS_ERROR_OVERFLOW_DETECTED);
Expand Down Expand Up @@ -98,7 +100,7 @@ int aws_hex_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, struct
output->buffer[written++] = HEX_CHARS[to_encode->ptr[i] & 0x0f];
}

output->buffer[written] = '\0';
AWS_ASSERT(written == encoded_len);
output->len = encoded_len;

return AWS_OP_SUCCESS;
Expand Down Expand Up @@ -153,6 +155,10 @@ static int s_hex_decode_char_to_int(char character, uint8_t *int_val) {
int aws_hex_compute_decoded_len(size_t to_decode_len, size_t *decoded_len) {
AWS_ASSERT(decoded_len);

/* For every 2 hex chars (rounded up) of encoded input, there will be 1 byte of decoded output.
* Rounding up is needed because if the input length is odd, we'll pretend there's an extra '0' at the start of the buffer */

/* adding 1 before dividing by 2 is a trick to round up during division */
size_t temp = (to_decode_len + 1);

if (AWS_UNLIKELY(temp < to_decode_len)) {
Expand Down Expand Up @@ -212,6 +218,10 @@ int aws_hex_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, struct
int aws_base64_compute_encoded_len(size_t to_encode_len, size_t *encoded_len) {
AWS_ASSERT(encoded_len);

/* For every 3 bytes (rounded up) of unencoded input, there will be 4 ascii characters of encoded output.
* Rounding is because the output will be padded with '=' chars if necessary to make it divisible by 4. */

/* adding 2 before dividing by 3 is a trick to round up during division */
size_t tmp = to_encode_len + 2;

if (AWS_UNLIKELY(tmp < to_encode_len)) {
Expand All @@ -220,7 +230,7 @@ int aws_base64_compute_encoded_len(size_t to_encode_len, size_t *encoded_len) {

tmp /= 3;
size_t overflow_check = tmp;
tmp = 4 * tmp + 1; /* plus one for the NULL terminator */
tmp = 4 * tmp;

if (AWS_UNLIKELY(tmp < overflow_check)) {
return aws_raise_error(AWS_ERROR_OVERFLOW_DETECTED);
Expand All @@ -243,57 +253,51 @@ int aws_base64_compute_decoded_len(const struct aws_byte_cursor *AWS_RESTRICT to
return AWS_OP_SUCCESS;
}

/* ensure it's divisible by 4 */
if (AWS_UNLIKELY(len & 0x03)) {
return aws_raise_error(AWS_ERROR_INVALID_BASE64_STR);
}

size_t tmp = len * 3;

if (AWS_UNLIKELY(tmp < len)) {
return aws_raise_error(AWS_ERROR_OVERFLOW_DETECTED);
}
/* For every 4 ascii characters of encoded input, there will be 3 bytes of decoded output (deal with padding later)
* decoded_len = 3/4 * len <-- note that result will be smaller than len, so overflow can be avoided
* = (len / 4) * 3 <-- divide before multiply to avoid overflow
*/
size_t decoded_len_tmp = (len / 4) * 3;

/* But last two ascii chars might be padding. */
AWS_ASSERT(len >= 4); /* we checked earlier len != 0, and was divisible by 4 */
size_t padding = 0;

if (len >= 2 && input[len - 1] == '=' && input[len - 2] == '=') { /*last two chars are = */
if (input[len - 1] == '=' && input[len - 2] == '=') { /*last two chars are = */
padding = 2;
} else if (input[len - 1] == '=') { /*last char is = */
padding = 1;
}

*decoded_len = (tmp / 4 - padding);
*decoded_len = decoded_len_tmp - padding;
return AWS_OP_SUCCESS;
}

int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, struct aws_byte_buf *AWS_RESTRICT output) {
AWS_ASSERT(to_encode->ptr);
AWS_ASSERT(output->buffer);
AWS_ASSERT(to_encode->len == 0 || to_encode->ptr != NULL);

size_t terminated_length = 0;
size_t encoded_length = 0;
if (AWS_UNLIKELY(aws_base64_compute_encoded_len(to_encode->len, &terminated_length))) {
if (AWS_UNLIKELY(aws_base64_compute_encoded_len(to_encode->len, &encoded_length))) {
return AWS_OP_ERR;
}

size_t needed_capacity = 0;
if (AWS_UNLIKELY(aws_add_size_checked(output->len, terminated_length, &needed_capacity))) {
if (AWS_UNLIKELY(aws_add_size_checked(output->len, encoded_length, &needed_capacity))) {
return AWS_OP_ERR;
}

if (AWS_UNLIKELY(output->capacity < needed_capacity)) {
return aws_raise_error(AWS_ERROR_SHORT_BUFFER);
}

/*
* For convenience to standard C functions expecting a null-terminated
* string, the output is terminated. As the encoding itself can be used in
* various ways, however, its length should never account for that byte.
*/
encoded_length = (terminated_length - 1);
AWS_ASSERT(needed_capacity == 0 || output->buffer != NULL);

if (aws_common_private_has_avx2()) {
aws_common_private_base64_encode_sse41(to_encode->ptr, output->buffer + output->len, to_encode->len);
output->buffer[output->len + encoded_length] = 0;
output->len += encoded_length;
return AWS_OP_SUCCESS;
}
Expand Down Expand Up @@ -329,9 +333,6 @@ int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, stru
}
}

/* it's a string add the null terminator. */
output->buffer[output->len + encoded_length] = 0;

output->len += encoded_length;

return AWS_OP_SUCCESS;
Expand Down
Loading
Loading