Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match CEL and Go duration literal parsing, while preserving the full range of values #38

Merged
merged 9 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ linters:
- wsl # generous whitespace violates house style
- exhaustive
- exhaustruct
- nonamedreturns
- mnd
- err113
- gochecknoglobals
issues:
exclude:
# Don't ban use of fmt.Errorf to create new errors, but the remaining
Expand Down
263 changes: 197 additions & 66 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"math"
"math/big"
"strconv"
"strings"
"time"
Expand All @@ -34,8 +35,6 @@ import (
"gopkg.in/yaml.v3"
)

// atTypeFieldName is the literal field name "@type".
// NOTE(review): presumably the key that carries the type URL when a
// google.protobuf.Any is written in YAML; its usage is not visible here — confirm.
const atTypeFieldName = "@type"

// Validator is an interface for validating a Protobuf message produced from a given YAML node.
type Validator interface {
// Validate the given message.
Expand All @@ -57,11 +56,6 @@ type UnmarshalOptions struct {
}
}

// protoResolver is the union of protoregistry.MessageTypeResolver and
// protoregistry.ExtensionTypeResolver: a single value that can resolve both
// message types and extension types.
type protoResolver interface {
protoregistry.MessageTypeResolver
protoregistry.ExtensionTypeResolver
}

// Unmarshal a Protobuf message from the given YAML data.
func Unmarshal(data []byte, message proto.Message) error {
return (UnmarshalOptions{}).Unmarshal(data, message)
Expand All @@ -76,6 +70,53 @@ func (o UnmarshalOptions) Unmarshal(data []byte, message proto.Message) error {
return o.unmarshalNode(&yamlFile, message, data)
}

// ParseDuration parses a duration string into a durationpb.Duration.
//
// The accepted syntax is an optional sign followed by one or more segments,
// each a decimal number (with an optional fractional part) and a mandatory
// unit, for example "1h30m", "-1.5s" or "+250ms". The bare strings "0", "+0"
// and "-0" are also accepted and yield the zero duration.
//
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
//
// This function supports the full range of durationpb.Duration values, including
// those outside the range of time.Duration.
//
// On success the returned message is never nil. An error is returned when the
// string is empty, malformed, or when the number of whole seconds does not fit
// in an int64.
func ParseDuration(str string) (*durationpb.Duration, error) {
	// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
	neg := false

	// Consume [-+]?
	if str != "" && (str[0] == '-' || str[0] == '+') {
		neg = str[0] == '-'
		str = str[1:]
	}

	switch str {
	case "0":
		// Special case: a bare "0" needs no unit. Return an empty (non-nil)
		// message so that callers can safely dereference the result.
		return &durationpb.Duration{}, nil
	case "":
		return nil, errors.New("invalid duration")
	}

	// Accumulate in arbitrary precision so that values beyond the range of
	// time.Duration (int64 nanoseconds) are not lost.
	totalNanos := &big.Int{}
	for str != "" {
		var err error
		str, err = parseDurationNext(str, totalNanos)
		if err != nil {
			return nil, err
		}
	}
	if neg {
		totalNanos.Neg(totalNanos)
	}

	// QuoRem truncates toward zero, so the seconds and nanos parts carry the
	// same sign, as durationpb.Duration requires.
	seconds, nanos := totalNanos.QuoRem(totalNanos, nanosPerSecond, &big.Int{})
	if !seconds.IsInt64() {
		return nil, errors.New("invalid duration: out of range")
	}
	return &durationpb.Duration{
		Seconds: seconds.Int64(),
		Nanos:   int32(nanos.Int64()),
	}, nil
}

func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message, data []byte) error {
if node.Kind == 0 {
return nil
Expand Down Expand Up @@ -121,6 +162,13 @@ func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message,
return nil
}

// atTypeFieldName is the literal field name "@type".
// NOTE(review): presumably the key that carries the type URL when a
// google.protobuf.Any is written in YAML; its usage is not visible here — confirm.
const atTypeFieldName = "@type"

// protoResolver is the union of protoregistry.MessageTypeResolver and
// protoregistry.ExtensionTypeResolver: a single value that can resolve both
// message types and extension types.
type protoResolver interface {
protoregistry.MessageTypeResolver
protoregistry.ExtensionTypeResolver
}

type unmarshaler struct {
options UnmarshalOptions
errors []error
Expand Down Expand Up @@ -683,54 +731,6 @@ const (
minTimestampSeconds = -62135596800
)

// parseDuration parses txt into duration.
//
// Format is decimal seconds with up to 9 fractional digits, followed by an 's',
// for example "3s", "1.5s" or "-0.25s". Surrounding whitespace is ignored.
//
// A leading '-' applies to both the seconds and the nanos, so the two fields
// never carry opposite signs. duration is only modified when parsing succeeds.
func parseDuration(txt string, duration *durationpb.Duration) error {
	// Remove trailing s.
	txt = strings.TrimSpace(txt)
	if len(txt) == 0 || txt[len(txt)-1] != 's' {
		return errors.New("missing trailing 's'")
	}
	value := txt[:len(txt)-1]

	// Split into seconds and nanos.
	parts := strings.Split(value, ".")
	if len(parts) > 2 {
		return errors.New("invalid duration: too many '.' characters")
	}

	seconds, err := strconv.ParseInt(parts[0], 10, 64)
	if err != nil {
		return err
	}

	var nanos int32
	if len(parts) == 2 {
		frac := parts[1]
		if len(frac) > 9 {
			return errors.New("too many fractional second digits")
		}
		// ParseUint rejects a sign prefix, so inputs such as "1.-5s" are
		// errors instead of producing a negative nanos next to positive seconds.
		digits, err := strconv.ParseUint(frac, 10, 32)
		if err != nil {
			return err
		}
		// Scale the fraction up to nanoseconds (9 decimal places).
		for i := len(frac); i < 9; i++ {
			digits *= 10
		}
		nanos = int32(digits)
		if strings.HasPrefix(value, "-") {
			nanos = -nanos
		}
	}

	duration.Seconds = seconds
	duration.Nanos = nanos
	return nil
}

// Format is RFC3339Nano, limited to the range 0001-01-01T00:00:00Z to
// 9999-12-31T23:59:59Z inclusive.
func parseTimestamp(txt string, timestamp *timestamppb.Timestamp) error {
Expand Down Expand Up @@ -770,19 +770,21 @@ func unmarshalDurationMsg(unm *unmarshaler, node *yaml.Node, message proto.Messa
if node.Kind != yaml.ScalarNode || len(node.Value) == 0 || isNull(node) {
return false
}
duration, ok := message.(*durationpb.Duration)
if !ok {
duration = &durationpb.Duration{}
}
err := parseDuration(node.Value, duration)
duration, err := ParseDuration(node.Value)
if err != nil {
unm.addErrorf(node, "invalid duration: %v", err)
} else if !ok {
// Set the fields dynamically.
return setFieldByName(message, "seconds", protoreflect.ValueOfInt64(duration.GetSeconds())) &&
setFieldByName(message, "nanos", protoreflect.ValueOfInt32(duration.GetNanos()))
unm.addError(node, err)
return true
}
return true

if value, ok := message.(*durationpb.Duration); ok {
value.Seconds = duration.GetSeconds()
value.Nanos = duration.GetNanos()
return true
}

// Set the fields dynamically.
return setFieldByName(message, "seconds", protoreflect.ValueOfInt64(duration.GetSeconds())) &&
setFieldByName(message, "nanos", protoreflect.ValueOfInt32(duration.GetNanos()))
}

func unmarshalTimestampMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
Expand Down Expand Up @@ -1184,3 +1186,132 @@ func findEntryByKey(cur *yaml.Node, key string) (*yaml.Node, *yaml.Node, bool) {
}
return nil, cur, false
}

var (
	// nanosPerSecond is the number of nanoseconds in one second.
	nanosPerSecond = big.NewInt(int64(time.Second))

	// nanosMap maps every accepted unit spelling to the length of that unit
	// in nanoseconds.
	nanosMap = map[string]*big.Int{
		"ns": big.NewInt(int64(time.Nanosecond)), // Identity for nanos.
		"us": big.NewInt(int64(time.Microsecond)),
		"µs": big.NewInt(int64(time.Microsecond)), // U+00B5 = micro symbol
		"μs": big.NewInt(int64(time.Microsecond)), // U+03BC = Greek letter mu
		"ms": big.NewInt(int64(time.Millisecond)),
		"s":  nanosPerSecond,
		"m":  big.NewInt(int64(time.Minute)),
		"h":  big.NewInt(int64(time.Hour)),
	}

	// unitsNames lists the normalized unit names, largest first; it is the
	// list shown in error messages.
	unitsNames = []string{"h", "m", "s", "ms", "us", "ns"}
)
Alfus marked this conversation as resolved.
Show resolved Hide resolved

// parseDurationNext parses a single segment of the duration string.
//
// A segment is a decimal number, with at least one digit before or after an
// optional '.', followed by a unit name found in nanosMap. The segment's value
// in nanoseconds is added to totalNanos and the unparsed remainder of str is
// returned.
//
// An error is returned when str is empty, does not start with a digit or '.',
// has no digits, has a missing or unknown unit, when the whole part overflows
// a uint64, or when the fractional part is not a whole number of nanoseconds.
func parseDurationNext(str string, totalNanos *big.Int) (string, error) {
	// The next character must be [0-9.]
	if str == "" || !(str[0] == '.' || '0' <= str[0] && str[0] <= '9') {
		return "", errors.New("invalid duration")
	}

	// pre reports whether a digit was seen before the dot.
	whole, str, pre, err := leadingInt(str)
	if err != nil {
		return "", err
	}

	var (
		frac  uint64
		scale *big.Int
		post  bool // Whether a digit was seen after the dot.
	)
	if str != "" && str[0] == '.' {
		str = str[1:]
		frac, scale, str, post = leadingFrac(str)
	}
	if !pre && !post {
		// A lone "." carries no digits at all.
		return "", errors.New("invalid duration")
	}

	end := unitEnd(str)
	if end == 0 {
		return "", fmt.Errorf("invalid duration: missing unit, expected one of %v", unitsNames)
	}
	unitName := str[:end]
	str = str[end:]
	nanosPerUnit, ok := nanosMap[unitName]
	if !ok {
		return "", fmt.Errorf("invalid duration: unknown unit, expected one of %v", unitsNames)
	}

	// Convert to nanos and add to total.
	// totalNanos += whole * nanosPerUnit + frac * nanosPerUnit / scale
	if whole > 0 {
		wholeNanos := new(big.Int).SetUint64(whole)
		wholeNanos.Mul(wholeNanos, nanosPerUnit)
		totalNanos.Add(totalNanos, wholeNanos)
	}
	if frac > 0 {
		fracNanos := new(big.Int).SetUint64(frac)
		fracNanos.Mul(fracNanos, nanosPerUnit)
		rem := &big.Int{}
		fracNanos.QuoRem(fracNanos, scale, rem)
		// A non-zero remainder means the fraction needs sub-nanosecond
		// precision, which cannot be represented.
		if rem.Sign() != 0 {
			return "", errors.New("invalid duration: fractional nanos")
		}
		totalNanos.Add(totalNanos, fracNanos)
	}
	return str, nil
}

// unitEnd returns the byte length of the unit name at the start of str: the
// index of the first '.', '-' or ASCII digit, or len(str) when none occurs.
func unitEnd(str string) int {
	if idx := strings.IndexAny(str, ".-0123456789"); idx >= 0 {
		return idx
	}
	return len(str)
}

// leadingFrac consumes the run of ASCII digits at the start of str as the
// digits of a decimal fraction.
//
// It returns the digits as an integer (result), the power of ten they must be
// divided by (scale), the unconsumed rest of str (rem), and whether at least
// one digit was consumed (post). Once another digit would no longer fit, the
// remaining digits are consumed but ignored, leaving result and scale as they
// were.
func leadingFrac(str string) (result uint64, scale *big.Int, rem string, post bool) {
	// Largest accumulator value that may still be multiplied by ten.
	const maxBeforeShift = (1<<63 - 1) / 10

	ten := big.NewInt(10)
	scale = big.NewInt(1)
	saturated := false
	pos := 0
	for pos < len(str) && str[pos] >= '0' && str[pos] <= '9' {
		digit := uint64(str[pos] - '0')
		pos++
		switch {
		case saturated:
			// Extra precision is dropped.
		case result > maxBeforeShift:
			saturated = true
		default:
			next := result*10 + digit
			if next > 1<<63 {
				saturated = true
				break
			}
			result = next
			scale.Mul(scale, ten)
		}
	}
	return result, scale, str[pos:], pos > 0
}

// leadingInt consumes the run of ASCII digits at the start of str and returns
// their value.
//
// rem is the unconsumed rest of str and pre reports whether at least one digit
// was consumed. If the value does not fit in a uint64, an "integer overflow"
// error is returned together with a zero result and the original str.
func leadingInt(str string) (result uint64, rem string, pre bool, err error) {
	var i int
	for ; i < len(str); i++ {
		c := str[i]
		if c < '0' || c > '9' {
			break
		}
		digit := uint64(c - '0')
		// Check before multiplying: result*10 + digit can wrap around to a
		// value that is still larger than result, so comparing afterwards
		// misses most overflows.
		if result > (math.MaxUint64-digit)/10 {
			return 0, str, i > 0, errors.New("integer overflow")
		}
		result = result*10 + digit
	}
	return result, str[i:], i > 0, nil
}
Loading