Skip to content

Commit

Permalink
add rounding logic and scale zero fix
Browse files Browse the repository at this point in the history
  • Loading branch information
himadripal committed Feb 23, 2025
1 parent 296e0fd commit bbd54d4
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
11 changes: 4 additions & 7 deletions arrow-cast/src/cast/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -603,21 +603,18 @@ mod tests {
12300000_i128
);

// `parse_decimal` does not handle scale=0 correctly. will enable it as part of code change PR.
// assert_eq!(parse_decimal::<Decimal128Type>("123.45", 38, 0)?, 123_i128);
assert_eq!(parse_decimal::<Decimal128Type>("123.45", 38, 0)?, 123_i128);
assert_eq!(
parse_decimal::<Decimal128Type>("123.45", 38, 5)?,
12345000_i128
);

//scale = 0 is not handled correctly in parse_decimal, next PR will fix it and enable this.
/*assert_eq!(
assert_eq!(
parse_decimal::<Decimal128Type>("123.4567891", 38, 0)?,
123_i128
);*/
);
assert_eq!(
parse_decimal::<Decimal128Type>("123.4567891", 38, 5)?,
12345678_i128
12345679_i128
);
Ok(())
}
Expand Down
6 changes: 3 additions & 3 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8453,15 +8453,15 @@ mod tests {
38,
2,
),
"0.12"
"0.13"
);
assert_eq!(
Decimal128Type::format_decimal(
parse_decimal::<Decimal128Type>(".1265", 38, 2).unwrap(),
38,
2,
),
"0.12"
"0.13"
);

assert_eq!(
Expand Down Expand Up @@ -8502,7 +8502,7 @@ mod tests {
38,
3,
),
"0.126"
"0.127"
);
}

Expand Down
36 changes: 33 additions & 3 deletions arrow-cast/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,16 @@ fn parse_e_notation<T: DecimalType>(
}

if exp < 0 {
result = result.div_wrapping(base.pow_wrapping(-exp as _));
let result_with_scale = result.div_wrapping(base.pow_wrapping(-exp as _));
let result_with_one_scale_up =
result.div_wrapping(base.pow_wrapping(-exp.add_wrapping(1) as _));
let rounding_digit =
result_with_one_scale_up.sub_wrapping(result_with_scale.mul_wrapping(base));
if rounding_digit >= T::Native::usize_as(5) {
result = result_with_scale.add_wrapping(T::Native::usize_as(1));
} else {
result = result_with_scale;
}
} else {
result = result.mul_wrapping(base.pow_wrapping(exp as _));
}
Expand All @@ -868,6 +877,7 @@ pub fn parse_decimal<T: DecimalType>(
let mut result = T::Native::usize_as(0);
let mut fractionals: i8 = 0;
let mut digits: u8 = 0;
let mut rounding_digit = -1; // to store digit after the scale for rounding
let base = T::Native::usize_as(10);

let bs = s.as_bytes();
Expand Down Expand Up @@ -897,6 +907,13 @@ pub fn parse_decimal<T: DecimalType>(
// Ignore leading zeros.
continue;
}
if fractionals == scale && scale != 0 {
// Capture the rounding digit once
if rounding_digit < 0 {
rounding_digit = (b - b'0') as i8;
}
continue;
}
digits += 1;
result = result.mul_wrapping(base);
result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize));
Expand Down Expand Up @@ -925,11 +942,17 @@ pub fn parse_decimal<T: DecimalType>(
"can't parse the string value {s} to decimal"
)));
}
if fractionals == scale && scale != 0 {
if fractionals == scale {
// Capture the rounding digit once
if rounding_digit < 0 {
rounding_digit = (b - b'0') as i8;
}
// We have processed all the digits that we need. All that
// is left is to validate that the rest of the string contains
// valid digits.
continue;
if scale != 0 {
continue;
}
}
fractionals += 1;
digits += 1;
Expand Down Expand Up @@ -986,6 +1009,13 @@ pub fn parse_decimal<T: DecimalType>(
"parse decimal overflow ({s})"
)));
}
if scale == 0 {
result = result.div_wrapping(base.pow_wrapping(fractionals as u32))
}
//add one if >=5
if rounding_digit >= 5 {
result = result.add_wrapping(T::Native::usize_as(1));
}
}

Ok(if negative {
Expand Down

0 comments on commit bbd54d4

Please sign in to comment.