diff --git a/pkg/stdlib/errors.go b/pkg/stdlib/errors.go new file mode 100644 index 0000000..e57e706 --- /dev/null +++ b/pkg/stdlib/errors.go @@ -0,0 +1,50 @@ +package stdlib + +import "fmt" + +type ArgumentError struct { + name string + wrapped error +} + +func NewArgumentError(name string, wrapped error) ArgumentError { + return ArgumentError{name: name, wrapped: wrapped} +} + +func (e ArgumentError) Error() string { + return fmt.Sprintf("%s(): argument error: %s", e.name, e.wrapped) +} + +func (e ArgumentError) Unwrap() error { + return e.wrapped +} + +type InvalidNumberOfArgumentsError struct { + Expected int + Actual int + Message string +} + +func NewInvalidNumberOfArgumentsError(name, message string, expected, actual int) error { + return NewArgumentError(name, InvalidNumberOfArgumentsError{Message: message, Actual: actual, Expected: expected}) +} + +func (e InvalidNumberOfArgumentsError) Error() string { + if e.Message == "" { + return fmt.Sprintf("accepts exactly %d argument, %d provided", e.Expected, e.Actual) + } + + return fmt.Sprintf(e.Message, e.Expected, e.Actual) +} + +type InvalidArgumentTypeError struct { + Message string +} + +func (e InvalidArgumentTypeError) Error() string { + return e.Message +} + +func NewInvalidArgumentTypeError(name, message string) error { + return NewArgumentError(name, InvalidArgumentTypeError{Message: message}) +} diff --git a/pkg/stdlib/functions.go b/pkg/stdlib/functions.go index 4133a0c..95255dc 100644 --- a/pkg/stdlib/functions.go +++ b/pkg/stdlib/functions.go @@ -2,12 +2,11 @@ package stdlib import ( "cmp" - "errors" "fmt" "path/filepath" - "reflect" "slices" "strings" + "time" "github.com/expr-lang/expr" "github.com/xhit/go-str2duration/v2" @@ -17,12 +16,12 @@ var FilepathDir = expr.Function( "filepath_dir", func(params ...any) (any, error) { if len(params) != 1 { - return nil, fmt.Errorf("filepath_dir: accepts exactly 1 argument, %d provided", len(params)) + return nil, NewInvalidNumberOfArgumentsError("filepath_dir", "", 1, len(params)) } val, ok := params[0].(string) if !ok { - return nil, errors.New("input to filepath_dir must be of type 'string'") + return nil, NewInvalidArgumentTypeError("filepath_dir", "input must be string") } return filepath.Dir(val), nil @@ -36,61 +35,74 @@ func UniqSlice[T cmp.Ordered](in []T) []T { return slices.Compact(in) } +// Uniq takes a list of strings or interface{}, sorts them +// and remove duplicated values var Uniq = expr.Function( "uniq", - func(params ...any) (any, error) { - arg := params[0] - val := reflect.ValueOf(arg) - - switch val.Kind() { //nolint:exhaustive - case reflect.Slice: - switch val.Type().Elem().Kind() { //nolint:exhaustive - case reflect.Interface: - var x []string - for _, v := range arg.([]any) { //nolint:forcetypeassert - x = append(x, fmt.Sprintf("%s", v)) - } + func(args ...any) (any, error) { + if len(args) != 1 { + return nil, NewInvalidNumberOfArgumentsError("uniq", "", 1, len(args)) + } - return UniqSlice(x), nil + arg := args[0] - case reflect.String: - return UniqSlice(arg.([]string)), nil //nolint:forcetypeassert + switch val := arg.(type) { + case []any: + var result []string + for _, v := range val { + result = append(result, fmt.Sprintf("%s", v)) } - } - return nil, errors.New("invalid type") + return UniqSlice(result), nil + + case []string: + return UniqSlice(val), nil + + default: + return nil, NewInvalidArgumentTypeError("uniq", fmt.Sprintf("invalid input, must be an array of [string] or [interface], got %T", arg)) + } }, + new(func([]string) []string), + new(func([]any) []string), ) +// Override built-in duration() function to provide support for additional periods +// - 'd' (day) +// - 'w' (week) +// - 'm' (month) var Duration = expr.Function( "duration", func(args ...any) (any, error) { + if len(args) != 1 { + return nil, NewInvalidNumberOfArgumentsError("duration", "", 1, len(args)) + } + val, ok := args[0].(string) if !ok { - return nil, errors.New("input to duration() must be of type 'string'") + return nil, NewInvalidArgumentTypeError("duration", fmt.Sprintf("invalid input, must be a string, got %T", args[0])) } return str2duration.ParseDuration(val) }, - str2duration.ParseDuration, + time.ParseDuration, ) var LimitPathDepthTo = expr.Function( "limit_path_depth_to", func(args ...any) (any, error) { if len(args) != 2 { - return nil, errors.New("limit_path_depth_to() expect exactly two arguments") + return nil, NewInvalidNumberOfArgumentsError("limit_path_depth_to", "", 2, len(args)) } input, ok := args[0].(string) if !ok { - return nil, errors.New("first input to limit_path_depth_to() must be of type 'string'") + return nil, NewInvalidArgumentTypeError("limit_path_depth_to", fmt.Sprintf("invalid input, first argument must be a 'string', got %T", args[0])) } length, ok := args[1].(int) if !ok { - return nil, errors.New("second input to limit_path_depth_to() must be of type 'int'") + return nil, NewInvalidArgumentTypeError("limit_path_depth_to", fmt.Sprintf("invalid input, first argument must be an 'int', got %T", args[0])) } chunks := strings.Split(input, "/") @@ -100,4 +112,5 @@ var LimitPathDepthTo = expr.Function( return strings.Join(chunks[0:length-1], "/"), nil // nosemgrep }, + new(func(string, int) string), ) diff --git a/pkg/stdlib/stdlib.go b/pkg/stdlib/stdlib.go index d42b2be..bcb6e46 100644 --- a/pkg/stdlib/stdlib.go +++ b/pkg/stdlib/stdlib.go @@ -1,6 +1,8 @@ package stdlib -import "github.com/expr-lang/expr" +import ( + "github.com/expr-lang/expr" +) var Functions = []expr.Option{ // Replace built-in duration function with one that supports "d" (days) and "w" (weeks)