From baefa5cf6e059ea01dc5e9924035bffa6cebaa64 Mon Sep 17 00:00:00 2001 From: Neal Patel Date: Thu, 9 Jan 2025 11:27:28 -0800 Subject: [PATCH] Fix decoding null values on non-pointer fields --- gen/decoder.go | 40 ++++++++++++++++++++++++++++++++-------- tests/basic_test.go | 11 +++++++++-- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/gen/decoder.go b/gen/decoder.go index df1a4bd..2af6aaf 100644 --- a/gen/decoder.go +++ b/gen/decoder.go @@ -63,22 +63,34 @@ func (g *Generator) genTypeDecoder(t reflect.Type, out string, tags fieldTags, i unmarshalerIface := reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem() if reflect.PtrTo(t).Implements(unmarshalerIface) { - fmt.Fprintln(g.out, ws+"("+out+").UnmarshalEasyJSON(in)") + fmt.Fprintln(g.out, ws+"if in.IsNull() {") + fmt.Fprintln(g.out, ws+" in.Skip()") + fmt.Fprintln(g.out, ws+"} else {") + fmt.Fprintln(g.out, ws+" ("+out+").UnmarshalEasyJSON(in)") + fmt.Fprintln(g.out, ws+"}") return nil } unmarshalerIface = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() if reflect.PtrTo(t).Implements(unmarshalerIface) { - fmt.Fprintln(g.out, ws+"if data := in.Raw(); in.Ok() {") - fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalJSON(data) )") + fmt.Fprintln(g.out, ws+"if in.IsNull() {") + fmt.Fprintln(g.out, ws+" in.Skip()") + fmt.Fprintln(g.out, ws+"} else {") + fmt.Fprintln(g.out, ws+" if data := in.Raw(); in.Ok() {") + fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalJSON(data) )") + fmt.Fprintln(g.out, ws+" }") fmt.Fprintln(g.out, ws+"}") return nil } unmarshalerIface = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() if reflect.PtrTo(t).Implements(unmarshalerIface) { - fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {") - fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalText(data) )") + fmt.Fprintln(g.out, ws+"if in.IsNull() {") + fmt.Fprintln(g.out, ws+" in.Skip()") + fmt.Fprintln(g.out, ws+"} else {") + fmt.Fprintln(g.out, ws+" if data := in.UnsafeBytes(); in.Ok() {") + fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalText(data) )") + fmt.Fprintln(g.out, ws+" }") fmt.Fprintln(g.out, ws+"}") return nil } @@ -110,13 +122,21 @@ func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags field ws := strings.Repeat(" ", indent) // Check whether type is primitive, needs to be done after interface check. if dec := customDecoders[t.String()]; dec != "" { - fmt.Fprintln(g.out, ws+out+" = "+dec) + fmt.Fprintln(g.out, ws+"if in.IsNull() {") + fmt.Fprintln(g.out, ws+" in.Skip()") + fmt.Fprintln(g.out, ws+"} else {") + fmt.Fprintln(g.out, ws+" "+out+" = "+dec) + fmt.Fprintln(g.out, ws+"}") return nil } else if dec := primitiveStringDecoders[t.Kind()]; dec != "" && tags.asString { if tags.intern && t.Kind() == reflect.String { dec = "in.StringIntern()" } - fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")") + fmt.Fprintln(g.out, ws+"if in.IsNull() {") + fmt.Fprintln(g.out, ws+" in.Skip()") + fmt.Fprintln(g.out, ws+"} else {") + fmt.Fprintln(g.out, ws+" "+out+" = "+g.getType(t)+"("+dec+")") + fmt.Fprintln(g.out, ws+"}") return nil } else if dec := primitiveDecoders[t.Kind()]; dec != "" { if tags.intern && t.Kind() == reflect.String { @@ -125,7 +145,11 @@ func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags field if tags.noCopy && t.Kind() == reflect.String { dec = "in.UnsafeString()" } - fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")") + fmt.Fprintln(g.out, ws+"if in.IsNull() {") + fmt.Fprintln(g.out, ws+" in.Skip()") + fmt.Fprintln(g.out, ws+"} else {") + fmt.Fprintln(g.out, ws+" "+out+" = "+g.getType(t)+"("+dec+")") + fmt.Fprintln(g.out, ws+"}") return nil } diff --git a/tests/basic_test.go b/tests/basic_test.go index 190f6d0..5bf93f1 100644 --- a/tests/basic_test.go +++ b/tests/basic_test.go @@ -330,14 +330,21 @@ func TestNil(t *testing.T) { } func TestUnmarshalNull(t *testing.T) { - p := primitiveTypesValue + p := PrimitiveTypes{ + String: str, + Ptr: &str, + } - data := `{"Ptr":null}` + data := `{"String":null,"Ptr":null}` if err := easyjson.Unmarshal([]byte(data), &p); err != nil { t.Errorf("easyjson.Unmarshal() error: %v", err) } + if p.String != str { + t.Errorf("Wanted %q, got %q", str, p.String) + } + if p.Ptr != nil { t.Errorf("Wanted nil, got %q", *p.Ptr) }