Skip to content

Commit

Permalink
add support for required fields and groups
Browse files Browse the repository at this point in the history
  • Loading branch information
jhump committed Aug 28, 2024
1 parent 00011dd commit ebbe537
Show file tree
Hide file tree
Showing 11 changed files with 490 additions and 38 deletions.
2 changes: 1 addition & 1 deletion buf.gen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ managed:
except:
- buf.build/bufbuild/protovalidate
plugins:
- plugin: buf.build/protocolbuffers/go:v1.33.0
- plugin: buf.build/protocolbuffers/go:v1.34.0
out: internal/gen/proto
opt: paths=source_relative
81 changes: 52 additions & 29 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ type UnmarshalOptions struct {
protoregistry.ExtensionTypeResolver
}

// If AllowPartial is set, input for messages that will result in missing
// required fields will not return an error.
AllowPartial bool

// DiscardUnknown specifies whether to discard unknown fields instead of
// returning an error.
DiscardUnknown bool
Expand Down Expand Up @@ -243,32 +247,32 @@ func (u *unmarshaler) unmarshalScalar(
node *yaml.Node,
field protoreflect.FieldDescriptor,
forKey bool,
) protoreflect.Value {
) (protoreflect.Value, bool) {
switch field.Kind() {
case protoreflect.BoolKind:
return protoreflect.ValueOfBool(u.unmarshalBool(node, forKey))
return protoreflect.ValueOfBool(u.unmarshalBool(node, forKey)), true
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
return protoreflect.ValueOfInt32(int32(u.unmarshalInteger(node, 32)))
return protoreflect.ValueOfInt32(int32(u.unmarshalInteger(node, 32))), true
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
return protoreflect.ValueOfInt64(u.unmarshalInteger(node, 64))
return protoreflect.ValueOfInt64(u.unmarshalInteger(node, 64)), true
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
return protoreflect.ValueOfUint32(uint32(u.unmarshalUnsigned(node, 32)))
return protoreflect.ValueOfUint32(uint32(u.unmarshalUnsigned(node, 32))), true
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
return protoreflect.ValueOfUint64(u.unmarshalUnsigned(node, 64))
return protoreflect.ValueOfUint64(u.unmarshalUnsigned(node, 64)), true
case protoreflect.FloatKind:
return protoreflect.ValueOfFloat32(float32(u.unmarshalFloat(node, 32)))
return protoreflect.ValueOfFloat32(float32(u.unmarshalFloat(node, 32))), true
case protoreflect.DoubleKind:
return protoreflect.ValueOfFloat64(u.unmarshalFloat(node, 64))
return protoreflect.ValueOfFloat64(u.unmarshalFloat(node, 64)), true
case protoreflect.StringKind:
u.checkKind(node, yaml.ScalarNode)
return protoreflect.ValueOfString(node.Value)
return protoreflect.ValueOfString(node.Value), true
case protoreflect.BytesKind:
return protoreflect.ValueOfBytes(u.unmarshalBytes(node))
return protoreflect.ValueOfBytes(u.unmarshalBytes(node)), true
case protoreflect.EnumKind:
return protoreflect.ValueOfEnum(u.unmarshalEnum(node, field))
return protoreflect.ValueOfEnum(u.unmarshalEnum(node, field)), true
default:
u.addErrorf(node, "unimplemented scalar type %v", field.Kind())
return protoreflect.Value{}
return protoreflect.Value{}, false
}
}

Expand Down Expand Up @@ -550,10 +554,12 @@ func (u *unmarshaler) unmarshalField(node *yaml.Node, field protoreflect.FieldDe
u.unmarshalList(node, field, message.ProtoReflect().Mutable(field).List())
case field.IsMap():
u.unmarshalMap(node, field, message.ProtoReflect().Mutable(field).Map())
case field.Kind() == protoreflect.MessageKind:
case field.Message() != nil:
u.unmarshalMessage(node, message.ProtoReflect().Mutable(field).Message().Interface(), false)
default:
message.ProtoReflect().Set(field, u.unmarshalScalar(node, field, false))
if val, ok := u.unmarshalScalar(node, field, false); ok {
message.ProtoReflect().Set(field, val)
}
}
}

Expand All @@ -569,29 +575,41 @@ func (u *unmarshaler) unmarshalList(node *yaml.Node, field protoreflect.FieldDes
}
default:
for _, itemNode := range node.Content {
list.Append(u.unmarshalScalar(itemNode, field, false))
val, ok := u.unmarshalScalar(itemNode, field, false)
if !ok {
continue
}
list.Append(val)
}
}
}
}

// Unmarshal the map, with explicit handling for maps to messages.
func (u *unmarshaler) unmarshalMap(node *yaml.Node, field protoreflect.FieldDescriptor, mapVal protoreflect.Map) {
if u.checkKind(node, yaml.MappingNode) {
mapKeyField := field.MapKey()
mapValueField := field.MapValue()
for i := 1; i < len(node.Content); i += 2 {
keyNode := node.Content[i-1]
valueNode := node.Content[i]
mapKey := u.unmarshalScalar(keyNode, mapKeyField, true)
switch mapValueField.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
mapValue := mapVal.NewValue()
u.unmarshalMessage(valueNode, mapValue.Message().Interface(), false)
mapVal.Set(mapKey.MapKey(), mapValue)
default:
mapVal.Set(mapKey.MapKey(), u.unmarshalScalar(valueNode, mapValueField, false))
if !u.checkKind(node, yaml.MappingNode) {
return
}
mapKeyField := field.MapKey()
mapValueField := field.MapValue()
for i := 1; i < len(node.Content); i += 2 {
keyNode := node.Content[i-1]
valueNode := node.Content[i]
mapKey, ok := u.unmarshalScalar(keyNode, mapKeyField, true)
if !ok {
continue
}
switch mapValueField.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
mapValue := mapVal.NewValue()
u.unmarshalMessage(valueNode, mapValue.Message().Interface(), false)
mapVal.Set(mapKey.MapKey(), mapValue)
default:
val, ok := u.unmarshalScalar(valueNode, mapValueField, false)
if !ok {
continue
}
mapVal.Set(mapKey.MapKey(), val)
}
}
}
Expand Down Expand Up @@ -648,6 +666,11 @@ func (u *unmarshaler) unmarshalMessage(node *yaml.Node, message proto.Message, f
return
}
u.unmarshalMessageFields(node, message, forAny)
if len(u.errors) == 0 && !u.options.AllowPartial {
if err := proto.CheckInitialized(message); err != nil {
u.addError(node, err)
}
}
}

func (u *unmarshaler) unmarshalMessageFields(node *yaml.Node, message proto.Message, forAny bool) {
Expand Down
38 changes: 38 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ import (
"testing"

testv1 "github.com/bufbuild/protoyaml-go/internal/gen/proto/buf/protoyaml/test/v1"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/durationpb"
)

Expand Down Expand Up @@ -98,6 +100,42 @@ func TestExtension(t *testing.T) {
require.Equal(t, "hi", proto.GetExtension(actual, testv1.E_P2TStringExt))
}

func TestEditions(t *testing.T) {
t.Parallel()

expected := &testv1.EditionsTest{
Name: proto.String("foobar"),
Nested: &testv1.EditionsTest_Nested{
Ids: []int64{0, 1, 1, 2, 3, 5, 8},
},
Enum: testv1.OpenEnum_ZERO,
}
actual := &testv1.EditionsTest{}
data := []byte(`
name: "foobar"
enum: ZERO
nested:
ids: [0, 1, 1, 2, 3, 5, 8]`)
err := Unmarshal(data, actual)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expected, actual, protocmp.Transform()))
}

func TestRequiredFields(t *testing.T) {
t.Parallel()

actual := &testv1.EditionsTest{}
err := Unmarshal([]byte(`enum: ZERO`), actual)
require.ErrorContains(t, err, "required field buf.protoyaml.test.v1.EditionsTest.name not set")

err = UnmarshalOptions{AllowPartial: true}.Unmarshal([]byte(`enum: ZERO`), actual)
require.NoError(t, err)
expected := &testv1.EditionsTest{
Enum: testv1.OpenEnum_ZERO,
}
require.Empty(t, cmp.Diff(expected, actual, protocmp.Transform()))
}

func TestDiscardUnknown(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 1 addition & 1 deletion internal/gen/proto/buf/protoyaml/test/v1/const.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ebbe537

Please sign in to comment.