diff --git a/app.go b/app.go index c340e15b7..7d485e790 100644 --- a/app.go +++ b/app.go @@ -504,6 +504,23 @@ func New(opts ...Option) *App { return app } + // At this point, we can run the fx.Evaluates (if any). + // As long as there's at least one evaluate per iteration, + // we'll have to keep unwinding. + // + // Keep evaluating until there are no more evaluates to run. + for app.root.evaluateAll() > 0 { + // TODO: is communicating the number of evalutes the best way? + if app.err != nil { + return app + } + + // TODO: fx.Module inside evaluates needs to build subscopes. + app.root.provideAll() + app.err = multierr.Append(app.err, app.root.decorateAll()) + // TODO: fx.WithLogger allowed inside an evaluate? + } + if err := app.root.invokeAll(); err != nil { app.err = err diff --git a/evaluate.go b/evaluate.go new file mode 100644 index 000000000..80362ab75 --- /dev/null +++ b/evaluate.go @@ -0,0 +1,153 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package fx + +import ( + "fmt" + "reflect" + "strings" + + "go.uber.org/fx/internal/fxreflect" +) + +// Evaluate specifies one or more evaluation functions. +// These are functions that accept dependencies from the graph +// and return an fx.Option. +// They may have the following signatures: +// +// func(...) fx.Option +// func(...) (fx.Option, error) +// +// These functions are run after provides and decorates. +// The resulting options are applied to the graph, +// and may introduce new provides, invokes, decorates, or evaluates. +// +// The effect of this is that parts of the graph can be dynamically generated +// based on dependency values. +// +// For example, a function with a dependency on a configuration struct +// could conditionally provide different implementations based on the value. +// +// fx.Evaluate(func(cfg *Config) fx.Option { +// if cfg.Environment == "production" { +// return fx.Provide(func(*sql.DB) Repository { +// return &sqlRepository{db: db} +// }), +// } else { +// return fx.Provide(func() Repository { +// return &memoryRepository{} +// }) +// } +// }) +// +// This is different from a normal provide that inspects the configuration +// because the dependency on '*sql.DB' is completely absent in the graph +// if the configuration is not "production". +func Evaluate(fns ...any) Option { + return evaluateOption{ + Targets: fns, + Stack: fxreflect.CallerStack(1, 0), + } +} + +type evaluateOption struct { + Targets []any + Stack fxreflect.Stack +} + +func (o evaluateOption) apply(mod *module) { + for _, target := range o.Targets { + mod.evaluates = append(mod.evaluates, evaluate{ + Target: target, + Stack: o.Stack, + }) + } +} + +func (o evaluateOption) String() string { + items := make([]string, len(o.Targets)) + for i, target := range o.Targets { + items[i] = fxreflect.FuncName(target) + } + return fmt.Sprintf("fx.Evaluate(%s)", strings.Join(items, ", ")) +} + +type evaluate struct { + Target any + Stack fxreflect.Stack +} + +func runEvaluate(m *module, e evaluate) (err error) { + target := e.Target + defer func() { + if err != nil { + err = fmt.Errorf("fx.Evaluate(%v) from:\n%+vFailed: %w", target, e.Stack, err) + } + }() + + // target is a function returning (Option, error). + // Use reflection to build a function with the same parameters, + // and invoke that in the container. + targetV := reflect.ValueOf(target) + targetT := targetV.Type() + inTypes := make([]reflect.Type, targetT.NumIn()) + for i := range targetT.NumIn() { + inTypes[i] = targetT.In(i) + } + outTypes := []reflect.Type{reflect.TypeOf((*error)(nil)).Elem()} + + // TODO: better way to extract information from the container + var opt Option + invokeFn := reflect.MakeFunc( + reflect.FuncOf(inTypes, outTypes, false), + func(args []reflect.Value) []reflect.Value { + out := targetV.Call(args) + switch len(out) { + case 2: + if err, _ := out[1].Interface().(error); err != nil { + return []reflect.Value{reflect.ValueOf(err)} + } + + fallthrough + case 1: + opt, _ = out[0].Interface().(Option) + + default: + panic("TODO: validation") + } + + return []reflect.Value{ + reflect.Zero(reflect.TypeOf((*error)(nil)).Elem()), + } + }, + ).Interface() + if err := m.scope.Invoke(invokeFn); err != nil { + return err + } + + if opt == nil { + // Assume no-op. + return nil + } + + opt.apply(m) + return nil +} diff --git a/evaluate_test.go b/evaluate_test.go new file mode 100644 index 000000000..73126b7b2 --- /dev/null +++ b/evaluate_test.go @@ -0,0 +1,114 @@ +// Copyright (c) 2025 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package fx_test + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/fx" + "go.uber.org/fx/fxtest" +) + +func TestEvaluate(t *testing.T) { + t.Run("ProvidesOptions", func(t *testing.T) { + type t1 struct{} + type t2 struct{} + + var evaluated, provided, invoked bool + app := fxtest.New(t, + fx.Evaluate(func() fx.Option { + evaluated = true + return fx.Provide(func() t1 { + provided = true + return t1{} + }) + }), + fx.Provide(func(t1) t2 { return t2{} }), + fx.Invoke(func(t2) { + invoked = true + }), + ) + defer app.RequireStart().RequireStop() + + assert.True(t, evaluated, "Evaluated function was not called") + assert.True(t, provided, "Provided function was not called") + assert.True(t, invoked, "Invoked function was not called") + }) + + t.Run("OptionalDependency", func(t *testing.T) { + type Config struct{ Dev bool } + + newBufWriter := func(b *bytes.Buffer) io.Writer { + return b + } + + newDiscardWriter := func() io.Writer { + return io.Discard + } + + newWriter := func(cfg Config) fx.Option { + if cfg.Dev { + return fx.Provide(newDiscardWriter) + } + + return fx.Provide(newBufWriter) + } + + t.Run("NoDependency", func(t *testing.T) { + var got io.Writer + app := fxtest.New(t, + fx.Evaluate(newWriter), + fx.Provide( + func() *bytes.Buffer { + t.Errorf("unexpected call to *bytes.Buffer") + return nil + }, + ), + fx.Supply(Config{Dev: true}), + fx.Populate(&got), + ) + defer app.RequireStart().RequireStop() + + assert.NotNil(t, got) + _, _ = io.WriteString(got, "hello") + }) + + t.Run("WithDependency", func(t *testing.T) { + var ( + buf bytes.Buffer + got io.Writer + ) + app := fxtest.New(t, + fx.Evaluate(newWriter), + fx.Supply(&buf, Config{Dev: false}), + fx.Populate(&got), + ) + defer app.RequireStart().RequireStop() + + assert.NotNil(t, got) + _, _ = io.WriteString(got, "hello") + assert.Equal(t, "hello", buf.String()) + }) + }) +} diff --git a/module.go b/module.go index 4615ad46a..3e533b231 100644 --- a/module.go +++ b/module.go @@ -125,6 +125,7 @@ type module struct { provides []provide invokes []invoke decorators []decorator + evaluates []evaluate modules []*module app *App log fxevent.Logger @@ -174,6 +175,7 @@ func (m *module) provideAll() { for _, p := range m.provides { m.provide(p) } + m.provides = nil for _, m := range m.modules { m.provideAll() @@ -264,6 +266,7 @@ func (m *module) installAllEventLoggers() { } } m.fallbackLogger = nil + m.logConstructor = nil } else if m.parent != nil { m.log = m.parent.log } @@ -308,6 +311,7 @@ func (m *module) invokeAll() error { return err } } + m.invokes = nil return nil } @@ -334,6 +338,7 @@ func (m *module) decorateAll() error { return err } } + m.decorators = nil for _, m := range m.modules { if err := m.decorateAll(); err != nil { @@ -405,3 +410,24 @@ func (m *module) replace(d decorator) error { }) return err } + +func (m *module) evaluateAll() (count int) { + for _, e := range m.evaluates { + m.evaluate(e) + count++ + } + m.evaluates = nil + + for _, m := range m.modules { + count += m.evaluateAll() + } + + return count +} + +func (m *module) evaluate(e evaluate) { + // TODO: events + if err := runEvaluate(m, e); err != nil { + m.app.err = err + } +}