Skip to content

Commit

Permalink
Support proto2 required fields and groups as well as related editions…
Browse files Browse the repository at this point in the history
… features (#45)

Introduces a new check during unmarshaling that all required fields
of a message have been populated. Adds a new `AllowPartial` flag
to `UnmarshalOptions`, which relaxes that check. Also fixes a panic
that can occur when using groups in proto2 or delimited message
encoding in editions.
  • Loading branch information
jhump authored Aug 28, 2024
1 parent 00011dd commit cb54552
Show file tree
Hide file tree
Showing 12 changed files with 498 additions and 42 deletions.
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ $(BIN):
@mkdir -p $(BIN)

$(BIN)/buf: $(BIN) Makefile
go install github.com/bufbuild/buf/cmd/buf@latest
go install github.com/bufbuild/buf/cmd/buf@v1.36.0

$(BIN)/license-header: $(BIN) Makefile
go install \
github.com/bufbuild/buf/private/pkg/licenseheader/cmd/license-header@latest
github.com/bufbuild/buf/private/pkg/licenseheader/cmd/license-header@v1.36.0

$(BIN)/golangci-lint: $(BIN) Makefile
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.59.0
2 changes: 1 addition & 1 deletion buf.gen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ managed:
except:
- buf.build/bufbuild/protovalidate
plugins:
- plugin: buf.build/protocolbuffers/go:v1.33.0
- plugin: buf.build/protocolbuffers/go:v1.34.0
out: internal/gen/proto
opt: paths=source_relative
86 changes: 56 additions & 30 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ type UnmarshalOptions struct {
protoregistry.ExtensionTypeResolver
}

// If AllowPartial is set, input for messages that will result in missing
// required fields will not return an error.
AllowPartial bool

// DiscardUnknown specifies whether to discard unknown fields instead of
// returning an error.
DiscardUnknown bool
Expand All @@ -71,7 +75,15 @@ func (o UnmarshalOptions) Unmarshal(data []byte, message proto.Message) error {
if err := yaml.Unmarshal(data, &yamlFile); err != nil {
return err
}
return o.unmarshalNode(&yamlFile, message, data)
if err := o.unmarshalNode(&yamlFile, message, data); err != nil {
return err
}
if !o.AllowPartial {
if err := proto.CheckInitialized(message); err != nil {
return err
}
}
return nil
}

// ParseDuration parses a duration string into a durationpb.Duration.
Expand Down Expand Up @@ -243,32 +255,32 @@ func (u *unmarshaler) unmarshalScalar(
node *yaml.Node,
field protoreflect.FieldDescriptor,
forKey bool,
) protoreflect.Value {
) (protoreflect.Value, bool) {
switch field.Kind() {
case protoreflect.BoolKind:
return protoreflect.ValueOfBool(u.unmarshalBool(node, forKey))
return protoreflect.ValueOfBool(u.unmarshalBool(node, forKey)), true
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
return protoreflect.ValueOfInt32(int32(u.unmarshalInteger(node, 32)))
return protoreflect.ValueOfInt32(int32(u.unmarshalInteger(node, 32))), true
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
return protoreflect.ValueOfInt64(u.unmarshalInteger(node, 64))
return protoreflect.ValueOfInt64(u.unmarshalInteger(node, 64)), true
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
return protoreflect.ValueOfUint32(uint32(u.unmarshalUnsigned(node, 32)))
return protoreflect.ValueOfUint32(uint32(u.unmarshalUnsigned(node, 32))), true
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
return protoreflect.ValueOfUint64(u.unmarshalUnsigned(node, 64))
return protoreflect.ValueOfUint64(u.unmarshalUnsigned(node, 64)), true
case protoreflect.FloatKind:
return protoreflect.ValueOfFloat32(float32(u.unmarshalFloat(node, 32)))
return protoreflect.ValueOfFloat32(float32(u.unmarshalFloat(node, 32))), true
case protoreflect.DoubleKind:
return protoreflect.ValueOfFloat64(u.unmarshalFloat(node, 64))
return protoreflect.ValueOfFloat64(u.unmarshalFloat(node, 64)), true
case protoreflect.StringKind:
u.checkKind(node, yaml.ScalarNode)
return protoreflect.ValueOfString(node.Value)
return protoreflect.ValueOfString(node.Value), true
case protoreflect.BytesKind:
return protoreflect.ValueOfBytes(u.unmarshalBytes(node))
return protoreflect.ValueOfBytes(u.unmarshalBytes(node)), true
case protoreflect.EnumKind:
return protoreflect.ValueOfEnum(u.unmarshalEnum(node, field))
return protoreflect.ValueOfEnum(u.unmarshalEnum(node, field)), true
default:
u.addErrorf(node, "unimplemented scalar type %v", field.Kind())
return protoreflect.Value{}
return protoreflect.Value{}, false
}
}

Expand Down Expand Up @@ -550,10 +562,12 @@ func (u *unmarshaler) unmarshalField(node *yaml.Node, field protoreflect.FieldDe
u.unmarshalList(node, field, message.ProtoReflect().Mutable(field).List())
case field.IsMap():
u.unmarshalMap(node, field, message.ProtoReflect().Mutable(field).Map())
case field.Kind() == protoreflect.MessageKind:
case field.Message() != nil:
u.unmarshalMessage(node, message.ProtoReflect().Mutable(field).Message().Interface(), false)
default:
message.ProtoReflect().Set(field, u.unmarshalScalar(node, field, false))
if val, ok := u.unmarshalScalar(node, field, false); ok {
message.ProtoReflect().Set(field, val)
}
}
}

Expand All @@ -569,29 +583,41 @@ func (u *unmarshaler) unmarshalList(node *yaml.Node, field protoreflect.FieldDes
}
default:
for _, itemNode := range node.Content {
list.Append(u.unmarshalScalar(itemNode, field, false))
val, ok := u.unmarshalScalar(itemNode, field, false)
if !ok {
continue
}
list.Append(val)
}
}
}
}

// Unmarshal the map, with explicit handling for maps to messages.
func (u *unmarshaler) unmarshalMap(node *yaml.Node, field protoreflect.FieldDescriptor, mapVal protoreflect.Map) {
if u.checkKind(node, yaml.MappingNode) {
mapKeyField := field.MapKey()
mapValueField := field.MapValue()
for i := 1; i < len(node.Content); i += 2 {
keyNode := node.Content[i-1]
valueNode := node.Content[i]
mapKey := u.unmarshalScalar(keyNode, mapKeyField, true)
switch mapValueField.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
mapValue := mapVal.NewValue()
u.unmarshalMessage(valueNode, mapValue.Message().Interface(), false)
mapVal.Set(mapKey.MapKey(), mapValue)
default:
mapVal.Set(mapKey.MapKey(), u.unmarshalScalar(valueNode, mapValueField, false))
if !u.checkKind(node, yaml.MappingNode) {
return
}
mapKeyField := field.MapKey()
mapValueField := field.MapValue()
for i := 1; i < len(node.Content); i += 2 {
keyNode := node.Content[i-1]
valueNode := node.Content[i]
mapKey, ok := u.unmarshalScalar(keyNode, mapKeyField, true)
if !ok {
continue
}
switch mapValueField.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
mapValue := mapVal.NewValue()
u.unmarshalMessage(valueNode, mapValue.Message().Interface(), false)
mapVal.Set(mapKey.MapKey(), mapValue)
default:
val, ok := u.unmarshalScalar(valueNode, mapValueField, false)
if !ok {
continue
}
mapVal.Set(mapKey.MapKey(), val)
}
}
}
Expand Down
38 changes: 38 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ import (
"testing"

testv1 "github.com/bufbuild/protoyaml-go/internal/gen/proto/buf/protoyaml/test/v1"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/durationpb"
)

Expand Down Expand Up @@ -98,6 +100,42 @@ func TestExtension(t *testing.T) {
require.Equal(t, "hi", proto.GetExtension(actual, testv1.E_P2TStringExt))
}

func TestEditions(t *testing.T) {
t.Parallel()

expected := &testv1.EditionsTest{
Name: proto.String("foobar"),
Nested: &testv1.EditionsTest_Nested{
Ids: []int64{0, 1, 1, 2, 3, 5, 8},
},
Enum: testv1.OpenEnum_OPEN_ENUM_UNSPECIFIED,
}
actual := &testv1.EditionsTest{}
data := []byte(`
name: "foobar"
enum: OPEN_ENUM_UNSPECIFIED
nested:
ids: [0, 1, 1, 2, 3, 5, 8]`)
err := Unmarshal(data, actual)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expected, actual, protocmp.Transform()))
}

func TestRequiredFields(t *testing.T) {
t.Parallel()

actual := &testv1.EditionsTest{}
err := Unmarshal([]byte(`enum: OPEN_ENUM_UNSPECIFIED`), actual)
require.ErrorContains(t, err, "required field buf.protoyaml.test.v1.EditionsTest.name not set")

err = UnmarshalOptions{AllowPartial: true}.Unmarshal([]byte(`enum: OPEN_ENUM_UNSPECIFIED`), actual)
require.NoError(t, err)
expected := &testv1.EditionsTest{
Enum: testv1.OpenEnum_OPEN_ENUM_UNSPECIFIED,
}
require.Empty(t, cmp.Diff(expected, actual, protocmp.Transform()))
}

func TestDiscardUnknown(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 1 addition & 1 deletion internal/gen/proto/buf/protoyaml/test/v1/const.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit cb54552

Please sign in to comment.