Skip to content

Commit

Permalink
Fix decoding null values on non-pointer fields
Browse files Browse the repository at this point in the history
  • Loading branch information
neal committed Jan 9, 2025
1 parent 8580601 commit baefa5c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
40 changes: 32 additions & 8 deletions gen/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,34 @@ func (g *Generator) genTypeDecoder(t reflect.Type, out string, tags fieldTags, i

unmarshalerIface := reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()
if reflect.PtrTo(t).Implements(unmarshalerIface) {
fmt.Fprintln(g.out, ws+"("+out+").UnmarshalEasyJSON(in)")
fmt.Fprintln(g.out, ws+"if in.IsNull() {")
fmt.Fprintln(g.out, ws+" in.Skip()")
fmt.Fprintln(g.out, ws+"} else {")
fmt.Fprintln(g.out, ws+" ("+out+").UnmarshalEasyJSON(in)")
fmt.Fprintln(g.out, ws+"}")
return nil
}

unmarshalerIface = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
if reflect.PtrTo(t).Implements(unmarshalerIface) {
fmt.Fprintln(g.out, ws+"if data := in.Raw(); in.Ok() {")
fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalJSON(data) )")
fmt.Fprintln(g.out, ws+"if in.IsNull() {")
fmt.Fprintln(g.out, ws+" in.Skip()")
fmt.Fprintln(g.out, ws+"} else {")
fmt.Fprintln(g.out, ws+" if data := in.Raw(); in.Ok() {")
fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalJSON(data) )")
fmt.Fprintln(g.out, ws+" }")
fmt.Fprintln(g.out, ws+"}")
return nil
}

unmarshalerIface = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
if reflect.PtrTo(t).Implements(unmarshalerIface) {
fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {")
fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalText(data) )")
fmt.Fprintln(g.out, ws+"if in.IsNull() {")
fmt.Fprintln(g.out, ws+" in.Skip()")
fmt.Fprintln(g.out, ws+"} else {")
fmt.Fprintln(g.out, ws+" if data := in.UnsafeBytes(); in.Ok() {")
fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalText(data) )")
fmt.Fprintln(g.out, ws+" }")
fmt.Fprintln(g.out, ws+"}")
return nil
}
Expand Down Expand Up @@ -110,13 +122,21 @@ func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags field
ws := strings.Repeat(" ", indent)
// Check whether type is primitive, needs to be done after interface check.
if dec := customDecoders[t.String()]; dec != "" {
fmt.Fprintln(g.out, ws+out+" = "+dec)
fmt.Fprintln(g.out, ws+"if in.IsNull() {")
fmt.Fprintln(g.out, ws+" in.Skip()")
fmt.Fprintln(g.out, ws+"} else {")
fmt.Fprintln(g.out, ws+" "+out+" = "+dec)
fmt.Fprintln(g.out, ws+"}")
return nil
} else if dec := primitiveStringDecoders[t.Kind()]; dec != "" && tags.asString {
if tags.intern && t.Kind() == reflect.String {
dec = "in.StringIntern()"
}
fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")")
fmt.Fprintln(g.out, ws+"if in.IsNull() {")
fmt.Fprintln(g.out, ws+" in.Skip()")
fmt.Fprintln(g.out, ws+"} else {")
fmt.Fprintln(g.out, ws+" "+out+" = "+g.getType(t)+"("+dec+")")
fmt.Fprintln(g.out, ws+"}")
return nil
} else if dec := primitiveDecoders[t.Kind()]; dec != "" {
if tags.intern && t.Kind() == reflect.String {
Expand All @@ -125,7 +145,11 @@ func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags field
if tags.noCopy && t.Kind() == reflect.String {
dec = "in.UnsafeString()"
}
fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")")
fmt.Fprintln(g.out, ws+"if in.IsNull() {")
fmt.Fprintln(g.out, ws+" in.Skip()")
fmt.Fprintln(g.out, ws+"} else {")
fmt.Fprintln(g.out, ws+" "+out+" = "+g.getType(t)+"("+dec+")")
fmt.Fprintln(g.out, ws+"}")
return nil
}

Expand Down
11 changes: 9 additions & 2 deletions tests/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,21 @@ func TestNil(t *testing.T) {
}

func TestUnmarshalNull(t *testing.T) {
p := primitiveTypesValue
p := PrimitiveTypes{
String: str,
Ptr: &str,
}

data := `{"Ptr":null}`
data := `{"String":null,"Ptr":null}`

if err := easyjson.Unmarshal([]byte(data), &p); err != nil {
t.Errorf("easyjson.Unmarshal() error: %v", err)
}

if p.String != str {
t.Errorf("Wanted %q, got %q", str, p.String)
}

if p.Ptr != nil {
t.Errorf("Wanted nil, got %q", *p.Ptr)
}
Expand Down

0 comments on commit baefa5c

Please sign in to comment.