From 940dda2d5618bc19712e0d421e5c4aeb81900281 Mon Sep 17 00:00:00 2001 From: Himanshu Rai Date: Thu, 30 Jan 2025 17:32:22 +0530 Subject: [PATCH 1/5] Add support for custom validation tags * Introduce converter options and use it in NewConverter * Add new option to specify handlers for custom validation tags * Add support for custom validation tags in all types * Improve ignore tags check for all types --- README.md | 154 ++++++--- custom/decimal/decimal_test.go | 3 +- custom/optional/optional_test.go | 3 +- zod.go | 545 +++++++++++++++++++++---------- zod_test.go | 60 +++- 5 files changed, 539 insertions(+), 226 deletions(-) diff --git a/README.md b/README.md index 7fcf48e..deaa1ba 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Zod + Generate = Zen -Converts Go structs with go-validator validations to Zod schemas. +Converts Go structs with [go-validator](https://github.com/go-playground/validator) validations to Zod schemas. Zen supports self-referential types and generic types. Other cyclic types (apart from self referential types) are not supported as they are not supported by zod itself. @@ -34,7 +34,7 @@ type Tree struct { fmt.Print(zen.StructToZodSchema(Tree{})) // We can also use create a converter and convert multiple types together -c := zen.NewConverter(nil) +c := zen.NewConverter() // Generic types are also supported type GenericPair[T any, U any] struct { @@ -123,42 +123,6 @@ export type PairMapStringIntBool = z.infer schema := converter.Export() ``` -## Custom Types - -We can pass type name mappings to custom conversion functions: - -```go -c := zen.NewConverter(map[string]zen.CustomFn{ - "github.com/shopspring/decimal.Decimal": func (c *zen.Converter, t reflect.Type, v string, i int) string { - // Shopspring's decimal type serialises to a string. - return "z.string()" - }, -}) - -c.Convert(User{ - Money decimal.Decimal -}) -``` - -Outputs: - -```typescript -export const UserSchema = z.object({ - Money: z.string(), -}) -export type User = z.infer -``` - -There are some custom types with tests in the "custom" directory. - -The function signature for custom type handlers is: - -```go -func(c *Converter, t reflect.Type, validate string, indent int) string -``` - -We can use `c` to process nested types. Indent level is for passing to other converter APIs. - ## Supported validations ### Network @@ -248,6 +212,120 @@ We can use `c` to process nested types. Indent level is for passing to other con - required checks that the value is not default, but we are not implementing this check for numbers and booleans +## Custom Tags + +In addition to the [go-validator](https://github.com/go-playground/validator) tags supported out of the box, custom tags can also be implemented. + +```go +type SortParams struct { + Order *string `json:"order,omitempty" validate:"omitempty,oneof=asc desc"` + Field *string `json:"field,omitempty"` +} + +type Request struct { + SortParams `validate:"sortFields=title address age dob"` + PaginationParams struct { + Start *int `json:"start,omitempty" validate:"omitempty,gt=0"` + End *int `json:"end,omitempty" validate:"omitempty,gt=0"` + } `validate:"pageParams"` + Search *string `json:"search,omitempty" validate:"identifier"` +} + +customTagHandlers := map[string]zen.CustomFn{ + "identifier": func(c *zen.Converter, t reflect.Type, validate string, indent int) string { + return ".refine((val) => !val || /^[a-z0-9_]*$/.test(val), 'Invalid search identifier')" + }, + "pageParams": func(c *zen.Converter, t reflect.Type, validate string, indent int) string { + return ".refine((val) => !val.start || !val.end || val.start < val.end, 'Start should be less than end')" + }, + "sortFields": func(c *zen.Converter, t reflect.Type, validate string, indent int) string { + sortFields := strings.Split(validate, " ") + for i := range sortFields { + sortFields[i] = fmt.Sprintf("'%s'", sortFields[i]) + } + return fmt.Sprintf(".extend({field: z.enum([%s])})", strings.Join(sortFields, ", ")) + }, +} +opt := zen.WithCustomTags(customTagHandlers) +c := zen.NewConverter(opt) + +c.Convert(Request{}) +``` + +Outputs: + +```ts +export const SortParamsSchema = z.object({ + order: z.enum(["asc", "desc"] as const).optional(), + field: z.string().optional(), +}) +export type SortParams = z.infer + +export const RequestSchema = z.object({ + PaginationParams: z.object({ + start: z.number().gt(0).optional(), + end: z.number().gt(0).optional(), + }).refine((val) => !val.start || !val.end || val.start < val.end, 'Start should be less than end'), + search: z.string().refine((val) => !val || /^[a-z0-9_]*$/.test(val), 'Invalid search identifier').optional(), +}).merge(SortParamsSchema.extend({field: z.enum(['title', 'address', 'age', 'dob'])})) +export type Request = z.infer +``` + +The function signature for custom type handlers is: + +```go +func(c *Converter, t reflect.Type, validate string, indent int) string +``` + +We can use `c` to process nested types. Indent level is for passing to other converter APIs. + +## Ignored Tags + +To ensure safety, `zen` will panic if it encounters unknown validation tags. If these tags are intentional, they should be explicitly ignored. + +```go +opt := zen.WithIgnoreTags("identifier") +c := zen.NewConverter(opt) +``` + +## Custom Types + +We can pass type name mappings to custom conversion functions: + +```go +customTypeHandlers := map[string]zen.CustomFn{ + "github.com/shopspring/decimal.Decimal": func (c *zen.Converter, t reflect.Type, v string, indent int) string { + // Shopspring's decimal type serialises to a string. + return "z.string()" + }, +} +opt := zen.WithCustomTypes(customTypeHandlers) +c := zen.NewConverter(opt) + +c.Convert(User{ + Money decimal.Decimal +}) +``` + +Outputs: + +```typescript +export const UserSchema = z.object({ + Money: z.string(), +}) +export type User = z.infer +``` + +There are some custom types with tests in the [custom](./custom) directory. + +The function signature for custom type handlers is: + +```go +func(c *Converter, t reflect.Type, validate string, indent int) string +``` + +We can use `c` to process nested types. Indent level is for passing to other converter APIs. + ## Caveats - Does not support cyclic types - it's a limitation of zod, but self-referential types are supported. diff --git a/custom/decimal/decimal_test.go b/custom/decimal/decimal_test.go index 7cc2b83..25790be 100644 --- a/custom/decimal/decimal_test.go +++ b/custom/decimal/decimal_test.go @@ -11,9 +11,10 @@ import ( ) func TestCustom(t *testing.T) { - c := zen.NewConverter(map[string]zen.CustomFn{ + opt := zen.WithCustomTypes(map[string]zen.CustomFn{ customDecimal.DecimalType: customDecimal.DecimalFunc, }) + c := zen.NewConverter(opt) type User struct { Money decimal.Decimal diff --git a/custom/optional/optional_test.go b/custom/optional/optional_test.go index 6d8599c..36c7c02 100644 --- a/custom/optional/optional_test.go +++ b/custom/optional/optional_test.go @@ -10,9 +10,10 @@ import ( ) func TestCustom(t *testing.T) { - c := zen.NewConverter(map[string]zen.CustomFn{ + opt := zen.WithCustomTypes(map[string]zen.CustomFn{ customoptional.OptionalType: customoptional.OptionalFunc, }) + c := zen.NewConverter(opt) type Profile struct { Bio string diff --git a/zod.go b/zod.go index 64fcc55..22e45cb 100644 --- a/zod.go +++ b/zod.go @@ -9,21 +9,65 @@ import ( "strings" ) -// NewConverter initializes and returns a new converter instance. The custom handler -// function map should be keyed on the fully qualified type name (excluding generic -// type arguments), ie. package.typename. -func NewConverter(custom map[string]CustomFn) Converter { - c := Converter{ - prefix: "", - outputs: make(map[string]entry), - custom: custom, +// Opt represents a converter option used to modify its behavior. +type Opt func(*Converter) + +// Adds prefix to the generated schema and type names. +func WithPrefix(prefix string) Opt { + return func(c *Converter) { + c.prefix = prefix + } +} + +// Adds custom handler/converters for types. The map should be keyed on +// the fully qualified type name (excluding generic type arguments), ie. +// package.typename. +func WithCustomTypes(custom map[string]CustomFn) Opt { + return func(c *Converter) { + for k, v := range custom { + c.customTypes[k] = v + } + } +} + +// Adds custom handler/converts for tags. The functions should return +// strings like `.regex(/[a-z0-9_]+/)` or `.refine((val) => val !== 0)` +// which can be appended to the generated schema. +func WithCustomTags(custom map[string]CustomFn) Opt { + return func(c *Converter) { + for k, v := range custom { + c.customTags[k] = v + } + } +} + +// Adds tags which should be ignored. Any unrecognized tag (which is also +// not ignored) results in panic. +func WithIgnoreTags(ignores ...string) Opt { + return func(c *Converter) { + c.ignoreTags = append(c.ignoreTags, ignores...) + } +} + +// NewConverter initializes and returns a new converter instance. +func NewConverter(opts ...Opt) *Converter { + c := &Converter{ + prefix: "", + customTypes: make(map[string]CustomFn), + customTags: make(map[string]CustomFn), + ignoreTags: []string{}, + outputs: make(map[string]entry), + } + + for _, opt := range opts { + opt(c) } return c } // AddType converts a struct type to corresponding zod schema. AddType can be called -// multiple times, followed by Export to get the corresonding zod schemas. +// multiple times, followed by Export to get the corresponding zod schemas. func (c *Converter) AddType(input interface{}) { t := reflect.TypeOf(input) @@ -65,24 +109,8 @@ func (c *Converter) ConvertSlice(inputs []interface{}) string { } // StructToZodSchema returns zod schema corresponding to a struct type. -func StructToZodSchema(input interface{}) string { - c := Converter{ - prefix: "", - outputs: make(map[string]entry), - } - - return c.Convert(input) -} - -// StructToZodSchemaWithPrefix returns zod schema corresponding to a struct type. -// The prefix is added to the generated schema and type names. -func StructToZodSchemaWithPrefix(prefix string, input interface{}) string { - c := Converter{ - prefix: prefix, - outputs: make(map[string]entry), - } - - return c.Convert(input) +func StructToZodSchema(input interface{}, opts ...Opt) string { + return NewConverter(opts...).Convert(input) } var typeMapping = map[reflect.Kind]string{ @@ -125,12 +153,13 @@ type meta struct { } type Converter struct { - prefix string - structs int - outputs map[string]entry - custom map[string]CustomFn - stack []meta - ignores []string + prefix string + customTypes map[string]CustomFn + customTags map[string]CustomFn + ignoreTags []string + structs int + outputs map[string]entry + stack []meta } func (c *Converter) addSchema(name string, data string) { @@ -313,7 +342,7 @@ func getFullName(t reflect.Type) string { func (c *Converter) handleCustomType(t reflect.Type, validate string, indent int) (string, bool) { fullName := getFullName(t) - custom, ok := c.custom[fullName] + custom, ok := c.customTypes[fullName] if ok { return custom(c, t, validate, indent), true } @@ -345,31 +374,79 @@ func (c *Converter) ConvertType(t reflect.Type, validate string, indent int) str } if t.Kind() == reflect.Struct { + var validateStr strings.Builder + var refines []string name := typeName(t) + parts := strings.Split(validate, ",") if name == "" { // Handle fields with non-defined types - these are inline. - return c.convertStruct(t, indent) + validateStr.WriteString(c.convertStruct(t, indent)) } else if name == "Time" { - var validateStr string - if validate != "" { - // We compare with both the zero value from go and the zero value that zod coerces to - if validate == "required" { - validateStr = ".refine((val) => val.getTime() !== new Date('0001-01-01T00:00:00Z').getTime() && val.getTime() !== new Date(0).getTime(), 'Invalid date')" - } - } // timestamps are to be coerced to date by zod. JSON.parse only serializes to string - return "z.coerce.date()" + validateStr + validateStr.WriteString("z.coerce.date()") } else { if c.stack[len(c.stack)-1].name == name { c.stack[len(c.stack)-1].selfRef = true - return fmt.Sprintf("z.lazy(() => %s)", schemaName(c.prefix, name)) + validateStr.WriteString(fmt.Sprintf("z.lazy(() => %s)", schemaName(c.prefix, name))) + } else { + // throws panic if there is a cycle + detectCycle(name, c.stack) + c.addSchema(name, c.convertStructTopLevel(t)) + validateStr.WriteString(schemaName(c.prefix, name)) + } + } + + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + idx := strings.Index(part, "=") + if idx == 0 || idx == len(part)-1 { + panic(fmt.Sprintf("invalid validation: %s", part)) + } + + var valName string + var valValue string + if idx == -1 { + valName = part + } else { + valName = part[:idx] + valValue = part[idx+1:] + } + + if c.checkIsIgnored(valName) { + continue + } + + if h, ok := c.customTags[valName]; ok { + v := h(c, reflect.TypeOf(0), valValue, 0) + if strings.HasPrefix(v, ".refine") { + refines = append(refines, v) + } else { + validateStr.WriteString(v) + } + continue + } + + switch valName { + case "required": + if name == "Time" { + // We compare with both the zero value from go and the zero value that zod coerces to + refines = append(refines, ".refine((val) => val.getTime() !== new Date('0001-01-01T00:00:00Z').getTime() && val.getTime() !== new Date(0).getTime(), 'Invalid date')") + } + default: + panic(fmt.Sprintf("unknown validation: %s", part)) } - // throws panic if there is a cycle - detectCycle(name, c.stack) - c.addSchema(name, c.convertStructTopLevel(t)) - return schemaName(c.prefix, name) } + + for _, refine := range refines { + validateStr.WriteString(refine) + } + + return validateStr.String() } // boolean, number, string, any @@ -441,7 +518,7 @@ func (c *Converter) convertField(f reflect.StructField, indent int, optional, nu // because nullability is processed before custom types, this makes sure // the custom type has control over nullability. fullName := getFullName(f.Type) - _, isCustom := c.custom[fullName] + _, isCustom := c.customTypes[fullName] optionalCall := "" if optional { @@ -477,7 +554,7 @@ func (c *Converter) getTypeField(f reflect.StructField, indent int, optional, nu // because nullability is processed before custom types, this makes sure // the custom type has control over nullability. fullName := getFullName(f.Type) - _, isCustom := c.custom[fullName] + _, isCustom := c.customTypes[fullName] optionalCallPre := "" optionalCallUndef := "" @@ -501,66 +578,105 @@ func (c *Converter) getTypeField(f reflect.StructField, indent int, optional, nu } func (c *Converter) convertSliceAndArray(t reflect.Type, validate string, indent int) string { - if t.Kind() == reflect.Array { - return fmt.Sprintf( - "%s.array()%s", - c.ConvertType(t.Elem(), getValidateAfterDive(validate), indent), fmt.Sprintf(".length(%d)", t.Len())) - } - var validateStr strings.Builder + var refines []string validateCurrent := getValidateCurrent(validate) - if validateCurrent != "" { - parts := strings.Split(validateCurrent, ",") + parts := strings.Split(validateCurrent, ",") + isArray := t.Kind() == reflect.Array - // eq and ne should be at the end since they output a refine function - sort.SliceStable(parts, func(i, j int) bool { - if strings.HasPrefix(parts[i], "ne") { - return false - } - if strings.HasPrefix(parts[j], "ne") { - return true + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + idx := strings.Index(part, "=") + if idx == 0 || idx == len(part)-1 { + panic(fmt.Sprintf("invalid validation: %s", part)) + } + + var valName string + var valValue string + if idx == -1 { + valName = part + } else { + valName = part[:idx] + valValue = part[idx+1:] + } + + if c.checkIsIgnored(valName) { + continue + } + + if h, ok := c.customTags[valName]; ok { + v := h(c, reflect.TypeOf(0), valValue, 0) + if strings.HasPrefix(v, ".refine") { + refines = append(refines, v) + } else { + validateStr.WriteString(v) } - return i < j - }) + continue + } - for _, part := range parts { - part = strings.TrimSpace(part) - if part == "omitempty" { - } else if part == "dive" { - break - } else if part == "required" { - } else if strings.HasPrefix(part, "min=") { - validateStr.WriteString(fmt.Sprintf(".min(%s)", part[4:])) - } else if strings.HasPrefix(part, "max=") { - validateStr.WriteString(fmt.Sprintf(".max(%s)", part[4:])) - } else if strings.HasPrefix(part, "len=") { - validateStr.WriteString(fmt.Sprintf(".length(%s)", part[4:])) - } else if strings.HasPrefix(part, "eq=") { - validateStr.WriteString(fmt.Sprintf(".length(%s)", part[3:])) - } else if strings.HasPrefix(part, "ne=") { - validateStr.WriteString(fmt.Sprintf(".refine((val) => val.length !== %s)", part[3:])) - } else if strings.HasPrefix(part, "gt=") { - val, err := strconv.Atoi(part[3:]) - if err != nil || val < 0 { - panic(fmt.Sprintf("invalid gt value: %s", part[3:])) - } - validateStr.WriteString(fmt.Sprintf(".min(%d)", val+1)) - } else if strings.HasPrefix(part, "gte=") { - validateStr.WriteString(fmt.Sprintf(".min(%s)", part[4:])) - } else if strings.HasPrefix(part, "lt=") { - val, err := strconv.Atoi(part[3:]) - if err != nil || val <= 0 { - panic(fmt.Sprintf("invalid lt value: %s", part[3:])) + if isArray { + panic(fmt.Sprintf("unknown validation: %s", part)) + } else { + if idx != -1 { + switch valName { + case "required": + case "min": + validateStr.WriteString(fmt.Sprintf(".min(%s)", valValue)) + case "max": + validateStr.WriteString(fmt.Sprintf(".max(%s)", valValue)) + case "len": + validateStr.WriteString(fmt.Sprintf(".length(%s)", valValue)) + case "eq": + validateStr.WriteString(fmt.Sprintf(".length(%s)", valValue)) + case "ne": + refines = append(refines, fmt.Sprintf(".refine((val) => val.length !== %s)", valValue)) + case "gt": + val, err := strconv.Atoi(valValue) + if err != nil || val < 0 { + panic(fmt.Sprintf("invalid gt value: %s", valValue)) + } + validateStr.WriteString(fmt.Sprintf(".min(%d)", val+1)) + case "gte": + validateStr.WriteString(fmt.Sprintf(".min(%s)", valValue)) + case "lt": + val, err := strconv.Atoi(valValue) + if err != nil || val <= 0 { + panic(fmt.Sprintf("invalid lt value: %s", valValue)) + } + validateStr.WriteString(fmt.Sprintf(".max(%d)", val-1)) + case "lte": + validateStr.WriteString(fmt.Sprintf(".max(%s)", valValue)) + + default: + panic(fmt.Sprintf("unknown validation: %s", part)) } - validateStr.WriteString(fmt.Sprintf(".max(%d)", val-1)) - } else if strings.HasPrefix(part, "lte=") { - validateStr.WriteString(fmt.Sprintf(".max(%s)", part[4:])) } else { - panic(fmt.Sprintf("unknown validation: %s", part)) + switch valName { + case "omitempty": + case "required": + case "dive": + goto forEnd + + default: + panic(fmt.Sprintf("unknown validation: %s", part)) + } } } } +forEnd: + if isArray { + validateStr.WriteString(fmt.Sprintf(".length(%d)", t.Len())) + } + + for _, refine := range refines { + validateStr.WriteString(refine) + } + return fmt.Sprintf( "%s.array()%s", c.ConvertType(t.Elem(), getValidateAfterDive(validate), indent), validateStr.String()) @@ -607,40 +723,86 @@ func (c *Converter) convertKeyType(t reflect.Type, validate string) string { func (c *Converter) convertMap(t reflect.Type, validate string, indent int) string { var validateStr strings.Builder - if validate != "" { - parts := strings.Split(validate, ",") + var refines []string + parts := strings.Split(validate, ",") - for _, part := range parts { - part = strings.TrimSpace(part) - if part == "omitempty" { - } else if part == "dive" { - break - } else if part == "required" { - validateStr.WriteString(".refine((val) => Object.keys(val).length > 0, 'Empty map')") - } else if strings.HasPrefix(part, "min=") { - validateStr.WriteString(fmt.Sprintf(".refine((val) => Object.keys(val).length >= %s, 'Map too small')", part[4:])) - } else if strings.HasPrefix(part, "max=") { - validateStr.WriteString(fmt.Sprintf(".refine((val) => Object.keys(val).length <= %s, 'Map too large')", part[4:])) - } else if strings.HasPrefix(part, "len=") { - validateStr.WriteString(fmt.Sprintf(".refine((val) => Object.keys(val).length === %s, 'Map wrong size')", part[4:])) - } else if strings.HasPrefix(part, "eq=") { - validateStr.WriteString(fmt.Sprintf(".refine((val) => Object.keys(val).length === %s, 'Map wrong size')", part[3:])) - } else if strings.HasPrefix(part, "ne=") { - validateStr.WriteString(fmt.Sprintf(".refine((val) => Object.keys(val).length !== %s, 'Map wrong size')", part[3:])) - } else if strings.HasPrefix(part, "gt=") { - validateStr.WriteString(fmt.Sprintf(".refine((val) => Object.keys(val).length > %s, 'Map too small')", part[3:])) - } else if strings.HasPrefix(part, "gte=") { - validateStr.WriteString(fmt.Sprintf(".refine((val) => Object.keys(val).length >= %s, 'Map too small')", part[4:])) - } else if strings.HasPrefix(part, "lt=") { - validateStr.WriteString(fmt.Sprintf(".refine((val) => Object.keys(val).length < %s, 'Map too large')", part[3:])) - } else if strings.HasPrefix(part, "lte=") { - validateStr.WriteString(fmt.Sprintf(".refine((val) => Object.keys(val).length <= %s, 'Map too large')", part[4:])) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + idx := strings.Index(part, "=") + if idx == 0 || idx == len(part)-1 { + panic(fmt.Sprintf("invalid validation: %s", part)) + } + + var valName string + var valValue string + if idx == -1 { + valName = part + } else { + valName = part[:idx] + valValue = part[idx+1:] + } + + if c.checkIsIgnored(valName) { + continue + } + + if h, ok := c.customTags[valName]; ok { + v := h(c, reflect.TypeOf(0), valValue, 0) + if strings.HasPrefix(v, ".refine") { + refines = append(refines, v) } else { + validateStr.WriteString(v) + } + continue + } + + if idx != -1 { + switch valName { + case "min": + refines = append(refines, fmt.Sprintf(".refine((val) => Object.keys(val).length >= %s, 'Map too small')", valValue)) + case "max": + refines = append(refines, fmt.Sprintf(".refine((val) => Object.keys(val).length <= %s, 'Map too large')", valValue)) + case "len": + refines = append(refines, fmt.Sprintf(".refine((val) => Object.keys(val).length === %s, 'Map wrong size')", valValue)) + case "eq": + refines = append(refines, fmt.Sprintf(".refine((val) => Object.keys(val).length === %s, 'Map wrong size')", valValue)) + case "ne": + refines = append(refines, fmt.Sprintf(".refine((val) => Object.keys(val).length !== %s, 'Map wrong size')", valValue)) + case "gt": + refines = append(refines, fmt.Sprintf(".refine((val) => Object.keys(val).length > %s, 'Map too small')", valValue)) + case "gte": + refines = append(refines, fmt.Sprintf(".refine((val) => Object.keys(val).length >= %s, 'Map too small')", valValue)) + case "lt": + refines = append(refines, fmt.Sprintf(".refine((val) => Object.keys(val).length < %s, 'Map too large')", valValue)) + case "lte": + refines = append(refines, fmt.Sprintf(".refine((val) => Object.keys(val).length <= %s, 'Map too large')", valValue)) + + default: + panic(fmt.Sprintf("unknown validation: %s", part)) + } + } else { + switch valName { + case "omitempty": + case "required": + refines = append(refines, ".refine((val) => Object.keys(val).length > 0, 'Empty map')") + case "dive": + goto forEnd + + default: panic(fmt.Sprintf("unknown validation: %s", part)) } } } +forEnd: + for _, refine := range refines { + validateStr.WriteString(refine) + } + return fmt.Sprintf(`z.record(%s, %s)%s`, c.convertKeyType(t.Key(), getValidateKeys(validate)), c.ConvertType(t.Elem(), getValidateValues(validate), indent), @@ -712,12 +874,8 @@ func getValidateValues(validate string) string { return validateValues } -func (c *Converter) SetIgnores(validations []string) { - c.ignores = validations -} - func (c *Converter) checkIsIgnored(part string) bool { - for _, ignore := range c.ignores { + for _, ignore := range c.ignoreTags { if part == ignore { return true } @@ -729,38 +887,44 @@ func (c *Converter) checkIsIgnored(part string) bool { // could support unusual cases like `validate:"omitempty,min=3,max=5"` func (c *Converter) validateNumber(validate string) string { var validateStr strings.Builder + var refines []string parts := strings.Split(validate, ",") - // eq and ne should be at the end since they output a refine function - sort.SliceStable(parts, func(i, j int) bool { - if strings.HasPrefix(parts[i], "eq") || strings.HasPrefix(parts[i], "len") || - strings.HasPrefix(parts[i], "ne") || strings.HasPrefix(parts[i], "oneof") || - strings.HasPrefix(parts[i], "required") { - return false + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue } - if strings.HasPrefix(parts[j], "eq") || strings.HasPrefix(parts[j], "len") || - strings.HasPrefix(parts[j], "ne") || strings.HasPrefix(parts[j], "oneof") || - strings.HasPrefix(parts[j], "required") { - return true + + idx := strings.Index(part, "=") + if idx == 0 || idx == len(part)-1 { + panic(fmt.Sprintf("invalid validation: %s", part)) } - return i < j - }) - for _, part := range parts { - part = strings.TrimSpace(part) - if c.checkIsIgnored(part) { + var valName string + var valValue string + if idx == -1 { + valName = part + } else { + valName = part[:idx] + valValue = part[idx+1:] + } + + if c.checkIsIgnored(valName) { continue } - if strings.ContainsRune(part, '=') { - idx := strings.Index(part, "=") - if idx == 0 || idx == len(part)-1 { - panic(fmt.Sprintf("invalid validation: %s", part)) + if h, ok := c.customTags[valName]; ok { + v := h(c, reflect.TypeOf(0), valValue, 0) + if strings.HasPrefix(v, ".refine") { + refines = append(refines, v) + } else { + validateStr.WriteString(v) } + continue + } - valName := part[:idx] - valValue := part[idx+1:] - + if idx != -1 { switch valName { case "gt": validateStr.WriteString(fmt.Sprintf(".gt(%s)", valValue)) @@ -771,15 +935,15 @@ func (c *Converter) validateNumber(validate string) string { case "lte", "max": validateStr.WriteString(fmt.Sprintf(".lte(%s)", valValue)) case "eq", "len": - validateStr.WriteString(fmt.Sprintf(".refine((val) => val === %s)", valValue)) + refines = append(refines, fmt.Sprintf(".refine((val) => val === %s)", valValue)) case "ne": - validateStr.WriteString(fmt.Sprintf(".refine((val) => val !== %s)", valValue)) + refines = append(refines, fmt.Sprintf(".refine((val) => val !== %s)", valValue)) case "oneof": vals := strings.Fields(valValue) if len(vals) == 0 { panic(fmt.Sprintf("invalid oneof validation: %s", part)) } - validateStr.WriteString(fmt.Sprintf(".refine((val) => [%s].includes(val))", strings.Join(vals, ", "))) + refines = append(refines, fmt.Sprintf(".refine((val) => [%s].includes(val))", strings.Join(vals, ", "))) default: panic(fmt.Sprintf("unknown validation: %s", part)) @@ -788,46 +952,61 @@ func (c *Converter) validateNumber(validate string) string { switch part { case "omitempty": case "required": - validateStr.WriteString(".refine((val) => val !== 0)") + refines = append(refines, ".refine((val) => val !== 0)") + default: panic(fmt.Sprintf("unknown validation: %s", part)) } } } + for _, refine := range refines { + validateStr.WriteString(refine) + } + return validateStr.String() } func (c *Converter) validateString(validate string) string { var validateStr strings.Builder + var refines []string parts := strings.Split(validate, ",") - // eq and ne should be at the end since they output a refine function - sort.SliceStable(parts, func(i, j int) bool { - if strings.HasPrefix(parts[i], "eq") || strings.HasPrefix(parts[i], "ne") { - return false + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue } - if strings.HasPrefix(parts[j], "eq") || strings.HasPrefix(parts[j], "ne") { - return true + + idx := strings.Index(part, "=") + if idx == 0 || idx == len(part)-1 { + panic(fmt.Sprintf("invalid validation: %s", part)) } - return i < j - }) - for _, part := range parts { - part = strings.TrimSpace(part) - if c.checkIsIgnored(part) { + var valName string + var valValue string + if idx == -1 { + valName = part + } else { + valName = part[:idx] + valValue = part[idx+1:] + } + + if c.checkIsIgnored(valName) { continue } - // We handle the parts which have = separately - if strings.ContainsRune(part, '=') { - idx := strings.Index(part, "=") - if idx == 0 || idx == len(part)-1 { - panic(fmt.Sprintf("invalid validation: %s", part)) - } - valName := part[:idx] - valValue := part[idx+1:] + if h, ok := c.customTags[valName]; ok { + v := h(c, reflect.TypeOf(""), validate, 0) + if strings.HasPrefix(v, ".refine") { + refines = append(refines, v) + } else { + validateStr.WriteString(v) + } + continue + } + if idx != -1 { switch valName { case "oneof": vals := splitParamsRegex.FindAllString(part[6:], -1) @@ -868,9 +1047,9 @@ func (c *Converter) validateString(validate string) string { case "startswith": validateStr.WriteString(fmt.Sprintf(".startsWith(\"%s\")", valValue)) case "eq": - validateStr.WriteString(fmt.Sprintf(".refine((val) => val === \"%s\")", valValue)) + refines = append(refines, fmt.Sprintf(".refine((val) => val === \"%s\")", valValue)) case "ne": - validateStr.WriteString(fmt.Sprintf(".refine((val) => val !== \"%s\")", valValue)) + refines = append(refines, fmt.Sprintf(".refine((val) => val !== \"%s\")", valValue)) default: panic(fmt.Sprintf("unknown validation: %s", part)) @@ -919,13 +1098,13 @@ func (c *Converter) validateString(validate string) string { case "boolean": validateStr.WriteString(".enum(['true', 'false'])") case "lowercase": - validateStr.WriteString(".refine((val) => val === val.toLowerCase())") + refines = append(refines, ".refine((val) => val === val.toLowerCase())") case "number": validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", numberRegexString)) case "numeric": validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", numericRegexString)) case "uppercase": - validateStr.WriteString(".refine((val) => val === val.toUpperCase())") + refines = append(refines, ".refine((val) => val === val.toUpperCase())") case "base64": validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", base64RegexString)) case "mongodb": @@ -945,7 +1124,7 @@ func (c *Converter) validateString(validate string) string { // //jsonSchema.parse(data); - validateStr.WriteString(".refine((val) => { try { JSON.parse(val); return true } catch { return false } })") + refines = append(refines, ".refine((val) => { try { JSON.parse(val); return true } catch { return false } })") case "jwt": validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", jWTRegexString)) case "latitude": @@ -985,6 +1164,10 @@ func (c *Converter) validateString(validate string) string { } } + for _, refine := range refines { + validateStr.WriteString(refine) + } + return validateStr.String() } diff --git a/zod_test.go b/zod_test.go index f6c50df..f52a0fe 100644 --- a/zod_test.go +++ b/zod_test.go @@ -3,6 +3,7 @@ package zen import ( "fmt" "reflect" + "strings" "testing" "time" @@ -111,7 +112,7 @@ func TestStructSimplePrefix(t *testing.T) { export type BotUser = z.infer `, - StructToZodSchemaWithPrefix("Bot", User{})) + StructToZodSchema(User{}, WithPrefix("Bot"))) } func TestNestedStruct(t *testing.T) { @@ -1774,7 +1775,7 @@ func TestConvertSlice(t *testing.T) { type Whim struct { Wham *Foo } - c := NewConverter(map[string]CustomFn{}) + c := NewConverter() types := []interface{}{ Zip{}, Whim{}, @@ -1975,11 +1976,11 @@ export type User = z.infer } func TestCustom(t *testing.T) { - c := NewConverter(map[string]CustomFn{ + c := NewConverter(WithCustomTypes(map[string]CustomFn{ "github.com/hypersequent/zen.Decimal": func(c *Converter, t reflect.Type, validate string, i int) string { return "z.string()" }, - }) + })) type Decimal struct { Value int @@ -2084,7 +2085,7 @@ type PairMap[K comparable, T any, U any] struct { } func TestGenerics(t *testing.T) { - c := NewConverter(nil) + c := NewConverter() c.AddType(StringIntPair{}) c.AddType(GenericPair[int, bool]{}) c.AddType(PairMap[string, int, bool]{}) @@ -2132,3 +2133,52 @@ export type TestSliceFieldsStruct = z.infer `, StructToZodSchema(TestSliceFieldsStruct{})) } + +func TestCustomTag(t *testing.T) { + type SortParams struct { + Order *string `json:"order,omitempty" validate:"omitempty,oneof=asc desc"` + Field *string `json:"field,omitempty"` + } + + type Request struct { + SortParams `validate:"sortFields=title address age dob"` + PaginationParams struct { + Start *int `json:"start,omitempty" validate:"omitempty,gt=0"` + End *int `json:"end,omitempty" validate:"omitempty,gt=0"` + } `validate:"pageParams"` + Search *string `json:"search,omitempty" validate:"identifier"` + } + + customTagHandlers := map[string]CustomFn{ + "identifier": func(c *Converter, t reflect.Type, validate string, i int) string { + return ".refine((val) => !val || /^[a-z0-9_]*$/.test(val), 'Invalid search identifier')" + }, + "pageParams": func(c *Converter, t reflect.Type, validate string, i int) string { + return ".refine((val) => !val.start || !val.end || val.start < val.end, 'Start should be less than end')" + }, + "sortFields": func(c *Converter, t reflect.Type, validate string, i int) string { + sortFields := strings.Split(validate, " ") + for i := range sortFields { + sortFields[i] = fmt.Sprintf("'%s'", sortFields[i]) + } + return fmt.Sprintf(".extend({field: z.enum([%s])})", strings.Join(sortFields, ", ")) + }, + } + + assert.Equal(t, `export const SortParamsSchema = z.object({ + order: z.enum(["asc", "desc"] as const).optional(), + field: z.string().optional(), +}) +export type SortParams = z.infer + +export const RequestSchema = z.object({ + PaginationParams: z.object({ + start: z.number().gt(0).optional(), + end: z.number().gt(0).optional(), + }).refine((val) => !val.start || !val.end || val.start < val.end, 'Start should be less than end'), + search: z.string().refine((val) => !val || /^[a-z0-9_]*$/.test(val), 'Invalid search identifier').optional(), +}).merge(SortParamsSchema.extend({field: z.enum(['title', 'address', 'age', 'dob'])})) +export type Request = z.infer + +`, NewConverter(WithCustomTags(customTagHandlers)).Convert(Request{})) +} From 1c52350887c47cbe9e9c8ec70c648b2b59305e64 Mon Sep 17 00:00:00 2001 From: Himanshu Rai Date: Thu, 30 Jan 2025 19:51:41 +0530 Subject: [PATCH 2/5] Upgrade go and golangci-lint version --- .github/workflows/ci.yml | 6 +++--- .golangci.yml | 4 ++-- Makefile | 2 +- custom/decimal/go.mod | 2 +- custom/optional/go.mod | 2 +- go.mod | 2 +- go.work | 2 +- 7 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dfe1186..6f85e80 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,9 +21,9 @@ jobs: with: fetch-depth: 2 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: - go-version: '^1.21.3' + go-version: '^1.23.5' - run: go version - name: Install gofumpt @@ -38,7 +38,7 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: - version: v1.59 + version: v1.63 args: --verbose --timeout=3m - name: Test diff --git a/.golangci.yml b/.golangci.yml index 32cbe51..50ca4ed 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,3 +1,3 @@ -run: - skip-files: +issues: + exclude-files: - regexes.go diff --git a/Makefile b/Makefile index 84079f5..ff90a84 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ GOCMD=GO111MODULE=on go linters-install: @golangci-lint --version >/dev/null 2>&1 || { \ echo "installing linting tools..."; \ - curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh| sh -s v1.52.2; \ + curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh| sh -s v1.63.4; \ } lint: linters-install diff --git a/custom/decimal/go.mod b/custom/decimal/go.mod index 91267cb..ce2cba2 100644 --- a/custom/decimal/go.mod +++ b/custom/decimal/go.mod @@ -1,6 +1,6 @@ module github.com/hypersequent/zen/custom/decimal -go 1.21 +go 1.23 replace github.com/hypersequent/zen => ../.. diff --git a/custom/optional/go.mod b/custom/optional/go.mod index 0cda467..7a3e146 100644 --- a/custom/optional/go.mod +++ b/custom/optional/go.mod @@ -1,6 +1,6 @@ module github.com/hypersequent/zen/custom/optional -go 1.21 +go 1.23 replace github.com/hypersequent/zen => ../.. diff --git a/go.mod b/go.mod index ffc111a..bc34ef1 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/hypersequent/zen -go 1.21 +go 1.23 require github.com/stretchr/testify v1.8.3 diff --git a/go.work b/go.work index cf33607..ec8dc31 100644 --- a/go.work +++ b/go.work @@ -1,4 +1,4 @@ -go 1.21 +go 1.23 use ( . From d47fde23d8168b9f530ef2697c67706dfd9d0125 Mon Sep 17 00:00:00 2001 From: Himanshu Rai Date: Thu, 13 Feb 2025 10:02:17 +0530 Subject: [PATCH 3/5] Fix comments --- README.md | 10 +- custom/decimal/decimal_test.go | 2 +- custom/optional/optional_test.go | 2 +- notes | 50 +++++++ zod.go | 232 ++++++++++--------------------- zod_test.go | 8 +- 6 files changed, 132 insertions(+), 172 deletions(-) create mode 100644 notes diff --git a/README.md b/README.md index deaa1ba..0dd095d 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ type Tree struct { fmt.Print(zen.StructToZodSchema(Tree{})) // We can also use create a converter and convert multiple types together -c := zen.NewConverter() +c := zen.NewConverterWithOpts() // Generic types are also supported type GenericPair[T any, U any] struct { @@ -114,7 +114,7 @@ export type PairMapStringIntBool = z.infer - Then using go templates and passing these struct names as input, we generate go code that is later used to generate the zod schemas. ```go.tmpl - converter := zen.NewConverter(make(map[string]zen.CustomFn)) + converter := zen.NewConverterWithOpts(make(map[string]zen.CustomFn)) {{range .TypesToGenerate}} converter.AddType(types.{{.}}{}) @@ -247,7 +247,7 @@ customTagHandlers := map[string]zen.CustomFn{ }, } opt := zen.WithCustomTags(customTagHandlers) -c := zen.NewConverter(opt) +c := zen.NewConverterWithOpts(opt) c.Convert(Request{}) ``` @@ -285,7 +285,7 @@ To ensure safety, `zen` will panic if it encounters unknown validation tags. If ```go opt := zen.WithIgnoreTags("identifier") -c := zen.NewConverter(opt) +c := zen.NewConverterWithOpts(opt) ``` ## Custom Types @@ -300,7 +300,7 @@ customTypeHandlers := map[string]zen.CustomFn{ }, } opt := zen.WithCustomTypes(customTypeHandlers) -c := zen.NewConverter(opt) +c := zen.NewConverterWithOpts(opt) c.Convert(User{ Money decimal.Decimal diff --git a/custom/decimal/decimal_test.go b/custom/decimal/decimal_test.go index 25790be..0433a0c 100644 --- a/custom/decimal/decimal_test.go +++ b/custom/decimal/decimal_test.go @@ -14,7 +14,7 @@ func TestCustom(t *testing.T) { opt := zen.WithCustomTypes(map[string]zen.CustomFn{ customDecimal.DecimalType: customDecimal.DecimalFunc, }) - c := zen.NewConverter(opt) + c := zen.NewConverterWithOpts(opt) type User struct { Money decimal.Decimal diff --git a/custom/optional/optional_test.go b/custom/optional/optional_test.go index 36c7c02..476f880 100644 --- a/custom/optional/optional_test.go +++ b/custom/optional/optional_test.go @@ -13,7 +13,7 @@ func TestCustom(t *testing.T) { opt := zen.WithCustomTypes(map[string]zen.CustomFn{ customoptional.OptionalType: customoptional.OptionalFunc, }) - c := zen.NewConverter(opt) + c := zen.NewConverterWithOpts(opt) type Profile struct { Bio string diff --git a/notes b/notes new file mode 100644 index 0000000..e0ce08b --- /dev/null +++ b/notes @@ -0,0 +1,50 @@ +func (c *Converter) convertStructTopLevel(t reflect.Type) string +func (c *Converter) getType(t reflect.Type, name string, indent int) string + +func (c *Converter) convertStruct(input reflect.Type, indent int) string +func (c *Converter) getTypeStruct(input reflect.Type, indent int) string + +func (c *Converter) convertField(f reflect.StructField, indent int, optional, nullable, anonymous bool) (string, bool) +func (c *Converter) getTypeField(f reflect.StructField, indent int, optional, nullable bool) string + +func (c *Converter) convertSliceAndArray(t reflect.Type, name, validate string, indent int) string +func (c *Converter) getTypeSliceAndArray(t reflect.Type, name string, indent int) string + +func (c *Converter) convertMap(t reflect.Type, name, validate string, indent int) string +func (c *Converter) getTypeMap(t reflect.Type, name string, indent int) string + +func (c *Converter) validateString(validate string) string +func (c *Converter) validateNumber(validate string) string + +type Converter struct { + prefix string + structs int + outputs map[string]entry + custom map[string]CustomFn + stack []meta +} + +type entry struct { + order int + data string +} + +type meta struct { + Name string + SelfRef bool +} + +// name, generic, validate, indent +type CustomFn func(*Converter, reflect.Type, string, string, string, int) string + +type byOrder []entry + +func (a byOrder) Len() int { return len(a) } +func (a byOrder) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a byOrder) Less(i, j int) bool { return a[i].order < a[j].order } + +func (c *Converter) AddType(input interface{}) +func (c *Converter) Convert(input interface{}) string +func (c *Converter) ConvertSlice(inputs []interface{}) string +func (c *Converter) ConvertType(t reflect.Type, name string, validate string, indent int) string +func (c *Converter) Export() string diff --git a/zod.go b/zod.go index 22e45cb..33ccb66 100644 --- a/zod.go +++ b/zod.go @@ -49,8 +49,8 @@ func WithIgnoreTags(ignores ...string) Opt { } } -// NewConverter initializes and returns a new converter instance. -func NewConverter(opts ...Opt) *Converter { +// NewConverterWithOpts initializes and returns a new converter instance. +func NewConverterWithOpts(opts ...Opt) *Converter { c := &Converter{ prefix: "", customTypes: make(map[string]CustomFn), @@ -66,6 +66,19 @@ func NewConverter(opts ...Opt) *Converter { return c } +// NewConverter initializes and returns a new converter instance. The custom handler +// function map should be keyed on the fully qualified type name (excluding generic +// type arguments), ie. package.typename. +func NewConverter(customTypes map[string]CustomFn) Converter { + c := Converter{ + prefix: "", + outputs: make(map[string]entry), + customTypes: customTypes, + } + + return c +} + // AddType converts a struct type to corresponding zod schema. AddType can be called // multiple times, followed by Export to get the corresponding zod schemas. func (c *Converter) AddType(input interface{}) { @@ -110,7 +123,7 @@ func (c *Converter) ConvertSlice(inputs []interface{}) string { // StructToZodSchema returns zod schema corresponding to a struct type. func StructToZodSchema(input interface{}, opts ...Opt) string { - return NewConverter(opts...).Convert(input) + return NewConverterWithOpts(opts...).Convert(input) } var typeMapping = map[reflect.Kind]string{ @@ -398,36 +411,8 @@ func (c *Converter) ConvertType(t reflect.Type, validate string, indent int) str } for _, part := range parts { - part = strings.TrimSpace(part) - if part == "" { - continue - } - - idx := strings.Index(part, "=") - if idx == 0 || idx == len(part)-1 { - panic(fmt.Sprintf("invalid validation: %s", part)) - } - - var valName string - var valValue string - if idx == -1 { - valName = part - } else { - valName = part[:idx] - valValue = part[idx+1:] - } - - if c.checkIsIgnored(valName) { - continue - } - - if h, ok := c.customTags[valName]; ok { - v := h(c, reflect.TypeOf(0), valValue, 0) - if strings.HasPrefix(v, ".refine") { - refines = append(refines, v) - } else { - validateStr.WriteString(v) - } + valName, _, done := c.preprocessValidationTagPart(part, &refines, &validateStr) + if done { continue } @@ -584,44 +569,17 @@ func (c *Converter) convertSliceAndArray(t reflect.Type, validate string, indent parts := strings.Split(validateCurrent, ",") isArray := t.Kind() == reflect.Array +forParts: for _, part := range parts { - part = strings.TrimSpace(part) - if part == "" { - continue - } - - idx := strings.Index(part, "=") - if idx == 0 || idx == len(part)-1 { - panic(fmt.Sprintf("invalid validation: %s", part)) - } - - var valName string - var valValue string - if idx == -1 { - valName = part - } else { - valName = part[:idx] - valValue = part[idx+1:] - } - - if c.checkIsIgnored(valName) { - continue - } - - if h, ok := c.customTags[valName]; ok { - v := h(c, reflect.TypeOf(0), valValue, 0) - if strings.HasPrefix(v, ".refine") { - refines = append(refines, v) - } else { - validateStr.WriteString(v) - } + valName, valValue, done := c.preprocessValidationTagPart(part, &refines, &validateStr) + if done { continue } if isArray { panic(fmt.Sprintf("unknown validation: %s", part)) } else { - if idx != -1 { + if valValue != "" { switch valName { case "required": case "min": @@ -659,7 +617,7 @@ func (c *Converter) convertSliceAndArray(t reflect.Type, validate string, indent case "omitempty": case "required": case "dive": - goto forEnd + break forParts default: panic(fmt.Sprintf("unknown validation: %s", part)) @@ -668,7 +626,6 @@ func (c *Converter) convertSliceAndArray(t reflect.Type, validate string, indent } } -forEnd: if isArray { validateStr.WriteString(fmt.Sprintf(".length(%d)", t.Len())) } @@ -726,41 +683,14 @@ func (c *Converter) convertMap(t reflect.Type, validate string, indent int) stri var refines []string parts := strings.Split(validate, ",") +forParts: for _, part := range parts { - part = strings.TrimSpace(part) - if part == "" { + valName, valValue, done := c.preprocessValidationTagPart(part, &refines, &validateStr) + if done { continue } - idx := strings.Index(part, "=") - if idx == 0 || idx == len(part)-1 { - panic(fmt.Sprintf("invalid validation: %s", part)) - } - - var valName string - var valValue string - if idx == -1 { - valName = part - } else { - valName = part[:idx] - valValue = part[idx+1:] - } - - if c.checkIsIgnored(valName) { - continue - } - - if h, ok := c.customTags[valName]; ok { - v := h(c, reflect.TypeOf(0), valValue, 0) - if strings.HasPrefix(v, ".refine") { - refines = append(refines, v) - } else { - validateStr.WriteString(v) - } - continue - } - - if idx != -1 { + if valValue != "" { switch valName { case "min": refines = append(refines, fmt.Sprintf(".refine((val) => Object.keys(val).length >= %s, 'Map too small')", valValue)) @@ -790,7 +720,7 @@ func (c *Converter) convertMap(t reflect.Type, validate string, indent int) stri case "required": refines = append(refines, ".refine((val) => Object.keys(val).length > 0, 'Empty map')") case "dive": - goto forEnd + break forParts default: panic(fmt.Sprintf("unknown validation: %s", part)) @@ -798,7 +728,6 @@ func (c *Converter) convertMap(t reflect.Type, validate string, indent int) stri } } -forEnd: for _, refine := range refines { validateStr.WriteString(refine) } @@ -891,40 +820,12 @@ func (c *Converter) validateNumber(validate string) string { parts := strings.Split(validate, ",") for _, part := range parts { - part = strings.TrimSpace(part) - if part == "" { - continue - } - - idx := strings.Index(part, "=") - if idx == 0 || idx == len(part)-1 { - panic(fmt.Sprintf("invalid validation: %s", part)) - } - - var valName string - var valValue string - if idx == -1 { - valName = part - } else { - valName = part[:idx] - valValue = part[idx+1:] - } - - if c.checkIsIgnored(valName) { - continue - } - - if h, ok := c.customTags[valName]; ok { - v := h(c, reflect.TypeOf(0), valValue, 0) - if strings.HasPrefix(v, ".refine") { - refines = append(refines, v) - } else { - validateStr.WriteString(v) - } + valName, valValue, done := c.preprocessValidationTagPart(part, &refines, &validateStr) + if done { continue } - if idx != -1 { + if valValue != "" { switch valName { case "gt": validateStr.WriteString(fmt.Sprintf(".gt(%s)", valValue)) @@ -973,40 +874,12 @@ func (c *Converter) validateString(validate string) string { parts := strings.Split(validate, ",") for _, part := range parts { - part = strings.TrimSpace(part) - if part == "" { + valName, valValue, done := c.preprocessValidationTagPart(part, &refines, &validateStr) + if done { continue } - idx := strings.Index(part, "=") - if idx == 0 || idx == len(part)-1 { - panic(fmt.Sprintf("invalid validation: %s", part)) - } - - var valName string - var valValue string - if idx == -1 { - valName = part - } else { - valName = part[:idx] - valValue = part[idx+1:] - } - - if c.checkIsIgnored(valName) { - continue - } - - if h, ok := c.customTags[valName]; ok { - v := h(c, reflect.TypeOf(""), validate, 0) - if strings.HasPrefix(v, ".refine") { - refines = append(refines, v) - } else { - validateStr.WriteString(v) - } - continue - } - - if idx != -1 { + if valValue != "" { switch valName { case "oneof": vals := splitParamsRegex.FindAllString(part[6:], -1) @@ -1171,6 +1044,43 @@ func (c *Converter) validateString(validate string) string { return validateStr.String() } +func (c *Converter) preprocessValidationTagPart(part string, refines *[]string, validateStr *strings.Builder) (string, string, bool) { + part = strings.TrimSpace(part) + if part == "" { + return "", "", true + } + + idx := strings.Index(part, "=") + if idx == 0 || idx == len(part)-1 { + panic(fmt.Sprintf("invalid validation: %s", part)) + } + + var valName string + var valValue string + if idx == -1 { + valName = part + } else { + valName = part[:idx] + valValue = part[idx+1:] + } + + if c.checkIsIgnored(valName) { + return "", "", true + } + + if h, ok := c.customTags[valName]; ok { + v := h(c, reflect.TypeOf(0), valValue, 0) + if strings.HasPrefix(v, ".refine") { + *refines = append(*refines, v) + } else { + (*validateStr).WriteString(v) + } + return "", "", true + } + + return valName, valValue, false +} + func isNullable(field reflect.StructField) bool { validateCurrent := getValidateCurrent(field.Tag.Get("validate")) diff --git a/zod_test.go b/zod_test.go index f52a0fe..202dd22 100644 --- a/zod_test.go +++ b/zod_test.go @@ -1775,7 +1775,7 @@ func TestConvertSlice(t *testing.T) { type Whim struct { Wham *Foo } - c := NewConverter() + c := NewConverterWithOpts() types := []interface{}{ Zip{}, Whim{}, @@ -1976,7 +1976,7 @@ export type User = z.infer } func TestCustom(t *testing.T) { - c := NewConverter(WithCustomTypes(map[string]CustomFn{ + c := NewConverterWithOpts(WithCustomTypes(map[string]CustomFn{ "github.com/hypersequent/zen.Decimal": func(c *Converter, t reflect.Type, validate string, i int) string { return "z.string()" }, @@ -2085,7 +2085,7 @@ type PairMap[K comparable, T any, U any] struct { } func TestGenerics(t *testing.T) { - c := NewConverter() + c := NewConverterWithOpts() c.AddType(StringIntPair{}) c.AddType(GenericPair[int, bool]{}) c.AddType(PairMap[string, int, bool]{}) @@ -2180,5 +2180,5 @@ export const RequestSchema = z.object({ }).merge(SortParamsSchema.extend({field: z.enum(['title', 'address', 'age', 'dob'])})) export type Request = z.infer -`, NewConverter(WithCustomTags(customTagHandlers)).Convert(Request{})) +`, NewConverterWithOpts(WithCustomTags(customTagHandlers)).Convert(Request{})) } From 22a7f6ba0c0c24094c7d8a1fc3cdb611d2ba8d69 Mon Sep 17 00:00:00 2001 From: Himanshu Rai Date: Thu, 13 Feb 2025 10:17:30 +0530 Subject: [PATCH 4/5] Add support for new omitzero json tag --- zod.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/zod.go b/zod.go index 33ccb66..0d0f4b0 100644 --- a/zod.go +++ b/zod.go @@ -1095,11 +1095,13 @@ func isNullable(field reflect.StructField) bool { return false } + jsonTag := field.Tag.Get("json") + // pointers can be nil, which are mapped to null in JS/TS. if field.Type.Kind() == reflect.Ptr { - // However, if a pointer field is tagged with "omitempty", it usually cannot be exported as "null" - // since nil is a pointer's empty value. - if strings.Contains(field.Tag.Get("json"), "omitempty") { + // However, if a pointer field is tagged with "omitempty"/"omitzero", it usually cannot be exported + // as "null" since nil is a pointer's empty/zero value. + if strings.Contains(jsonTag, "omitempty") || strings.Contains(jsonTag, "omitzero") { // Unless it is a pointer to a slice, a map, a pointer, or an interface // because values with those types can themselves be nil and will be exported as "null". k := field.Type.Elem().Kind() @@ -1112,7 +1114,7 @@ func isNullable(field reflect.StructField) bool { // nil slices and maps are exported as null so these types are usually nullable if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Map { // unless there are also optional in which case they are no longer nullable - return !strings.Contains(field.Tag.Get("json"), "omitempty") + return !strings.Contains(jsonTag, "omitempty") && !strings.Contains(jsonTag, "omitzero") } return false @@ -1158,8 +1160,9 @@ func isOptional(field reflect.StructField) bool { return false } - // Otherwise, omitempty zero-values are omitted and are mapped to undefined in JS/TS. - return strings.Contains(field.Tag.Get("json"), "omitempty") + // Otherwise, omitempty/omitzero zero-values are omitted and are mapped to undefined in JS/TS. + jsonTag := field.Tag.Get("json") + return strings.Contains(jsonTag, "omitempty") || strings.Contains(jsonTag, "omitzero") } func indentation(level int) string { From d47c99c63c10c8204ee5ab0178313d3e2515f9be Mon Sep 17 00:00:00 2001 From: Andrian Budantsov Date: Fri, 14 Feb 2025 08:42:13 +0400 Subject: [PATCH 5/5] minor change --- zod.go | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/zod.go b/zod.go index 0d0f4b0..9e56a2d 100644 --- a/zod.go +++ b/zod.go @@ -66,17 +66,12 @@ func NewConverterWithOpts(opts ...Opt) *Converter { return c } -// NewConverter initializes and returns a new converter instance. The custom handler -// function map should be keyed on the fully qualified type name (excluding generic -// type arguments), ie. package.typename. +// Deprecated: NewConverter is deprecated. Use NewConverterWithOpts(WithCustomTypes(customTypes)) instead. +// Example: +// +// converter := NewConverterWithOpts(WithCustomTypes(customTypes)) func NewConverter(customTypes map[string]CustomFn) Converter { - c := Converter{ - prefix: "", - outputs: make(map[string]entry), - customTypes: customTypes, - } - - return c + return *NewConverterWithOpts(WithCustomTypes(customTypes)) } // AddType converts a struct type to corresponding zod schema. AddType can be called