Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sdks/go: add string utf-8 check to vet runner for serialization #33949

Merged
merged 3 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sdks/go/pkg/beam/core/runtime/graphx/serialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ func ms2duration(d int64) time.Duration {
return time.Duration(d) * time.Millisecond
}

// encodeFn encodes a graph.Fn into a v1pb.Fn proto message.
// All string fields in the DoFn struct must be UTF-8 compliant. The vet runner
// (--beam_strict) will detect any non-UTF8 strings that would fail during JSON serialization.
// The check will be skipped for subtypes that implement the MarshalJSON and
// UnmarshalJSON interface methods.
func encodeFn(u *graph.Fn) (*v1pb.Fn, error) {
switch {
case u.DynFn != nil:
Expand Down
105 changes: 96 additions & 9 deletions sdks/go/pkg/beam/runners/vet/vet.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ package vet
import (
"bytes"
"context"
"encoding/json"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -91,8 +92,8 @@ func Evaluate(_ context.Context, p *beam.Pipeline) (*Eval, error) {
e := newEval()

e.diag("/**\n")
e.extractFromMultiEdges(edges)
return e, nil
err = e.extractFromMultiEdges(edges)
return e, err
}

func newEval() *Eval {
Expand Down Expand Up @@ -133,22 +134,27 @@ type Eval struct {

// extractFromMultiEdges audits the given pipeline edges so we can determine if
// this pipeline will run without reflection.
func (e *Eval) extractFromMultiEdges(edges []*graph.MultiEdge) {
func (e *Eval) extractFromMultiEdges(edges []*graph.MultiEdge) error {
e.diag("PTransform Audit:\n")
for _, edge := range edges {
switch edge.Op {
case graph.ParDo:
// Gets the ParDo's identifier
e.diagf("pardo %s", edge.Name())
e.extractGraphFn((*graph.Fn)(edge.DoFn))
if err := e.extractGraphFn((*graph.Fn)(edge.DoFn)); err != nil {
return err
}
case graph.Combine:
e.diagf("combine %s", edge.Name())
e.extractGraphFn((*graph.Fn)(edge.CombineFn))
if err := e.extractGraphFn((*graph.Fn)(edge.CombineFn)); err != nil {
return err
}
default:
continue
}
e.diag("\n")
}
return nil
}

// Performant returns whether this pipeline needs additional registrations
Expand Down Expand Up @@ -485,6 +491,73 @@ func (e *Eval) Bytes() []byte {
return e.w.Bytes()
}

// checkStructFieldsUTF8 recursively validates that all string fields in the
// given value are UTF-8 compliant.
// It handles structs, slices, arrays, maps, and individual strings while
// avoiding infinite recursion on circular references.
// The function skips validation for types that implement both json.Marshaler
// and json.Unmarshaler interfaces.
//
// Parameters:
// - v: reflect.Value to check
// - seen: map tracking visited values to prevent infinite recursion
//
// Returns:
// - error if any string field contains invalid UTF-8 encoding, nil otherwise
func (e *Eval) checkStructFieldsUTF8(v reflect.Value, seen map[reflect.Value]bool) error {
if !v.IsValid() || seen[v] {
return nil
}

// Track visited values to prevent infinite recursion on circular references.
seen[v] = true

t := v.Type()

// Skip if type implements JSON marshaling.
_, hasMarshaler := reflect.New(t).Interface().(json.Marshaler)
_, hasUnmarshaler := reflect.New(t).Interface().(json.Unmarshaler)
if hasMarshaler && hasUnmarshaler {
return nil
}

switch t.Kind() {
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if !field.CanInterface() {
// Skip unexported fields.
continue
}
if err := e.checkStructFieldsUTF8(field, seen); err != nil {
return err
}
}
case reflect.Slice, reflect.Array:
for i := 0; i < v.Len(); i++ {
if err := e.checkStructFieldsUTF8(v.Index(i), seen); err != nil {
return err
}
}
case reflect.Map:
iter := v.MapRange()
for iter.Next() {
if err := e.checkStructFieldsUTF8(iter.Key(), seen); err != nil {
return err
}
if err := e.checkStructFieldsUTF8(iter.Value(), seen); err != nil {
return err
}
}
case reflect.String:
str := v.String()
if !utf8.ValidString(str) {
return fmt.Errorf("non-UTF8 compliant string found: %q", str)
}
}
return nil
}

// We need to take graph.Fns (which can be created from any from graph.NewFn)
// and convert them to all needed function caller signatures,
// and emitters.
Expand All @@ -500,17 +573,29 @@ func (e *Eval) Bytes() []byte {

// extractGraphFn does the analysis of the function and determines what things need generating.
// A single line is used, unless it's a struct, at which point one line per implemented method
// is used.
func (e *Eval) extractGraphFn(fn *graph.Fn) {
// is used. For structs, it also validates UTF-8 compliance of all exported string fields.
func (e *Eval) extractGraphFn(fn *graph.Fn) error {
if fn.DynFn != nil {
// TODO(https://github.com/apache/beam/issues/19401) handle dynamics if necessary (probably not since it's got general function handling)
e.diag(" dynamic function")
return
return nil
}
if fn.Recv != nil {
e.diagf(" struct[[%T]]", fn.Recv)

rt := reflectx.SkipPtr(reflect.TypeOf(fn.Recv)) // We need the value not the pointer that's used.
// We need the value not the pointer that's used.
rt := reflectx.SkipPtr(reflect.TypeOf(fn.Recv))
rv := reflect.ValueOf(fn.Recv)
if rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}

// Add UTF-8 compliance check for struct fields.
seen := make(map[reflect.Value]bool)
if err := e.checkStructFieldsUTF8(rv, seen); err != nil {
return err
}

if tk, ok := runtime.TypeKey(rt); ok {
if t, found := runtime.LookupType(tk); !found {
e.needType(tk, rt)
Expand All @@ -532,6 +617,8 @@ func (e *Eval) extractGraphFn(fn *graph.Fn) {
}
e.extractFuncxFn(fn.Fn)
}

return nil
}

type mthd struct {
Expand Down
50 changes: 49 additions & 1 deletion sdks/go/pkg/beam/runners/vet/vet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,77 @@ package vet

import (
"context"
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/vet/testpipeline"
"strings"
"testing"

"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/vet/testpipeline"

"github.com/apache/beam/sdks/v2/go/pkg/beam"
)

type stringContentDoFn struct {
Name string
}

func (fn *stringContentDoFn) ProcessElement(ctx context.Context, _ []byte) error {
return nil
}

type errorType int

const (
noError errorType = iota
utf8Error
)

func TestEvaluate(t *testing.T) {
tests := []struct {
name string
c func(beam.Scope)
perf, exp, ref, reg bool
errType errorType
errMsg string
}{
{name: "Performant", c: testpipeline.Performant, perf: true},
{name: "FunctionReg", c: testpipeline.FunctionReg, exp: true, ref: true, reg: true},
{name: "ShimNeeded", c: testpipeline.ShimNeeded, ref: true},
{name: "TypeReg", c: testpipeline.TypeReg, ref: true, reg: true},
{
name: "NonUTF8DoFn",
c: func(s beam.Scope) {
fn := &stringContentDoFn{Name: "hello\xFFworld"}
beam.ParDo0(s, fn, beam.Impulse(s))
},
errType: utf8Error,
errMsg: "non-UTF8 compliant string found",
},
{
name: "ValidUTF8DoFn",
c: func(s beam.Scope) {
fn := &stringContentDoFn{Name: "helloworld"}
beam.ParDo0(s, fn, beam.Impulse(s))
},
errType: noError,
ref: true,
reg: true,
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
p, s := beam.NewPipelineWithRoot()
test.c(s)
e, err := Evaluate(context.Background(), p)
if test.errType != noError {
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), test.errMsg) {
t.Fatalf("error %q doesn't contain %q", err.Error(), test.errMsg)
}
return
}
if err != nil {
t.Fatalf("failed to evaluate testpipeline.Pipeline: %v", err)
}
Expand Down
Loading