From 121fd5019e9929dc6f53513bd1848744db390767 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juho=20M=C3=A4kinen?= Date: Wed, 12 Feb 2025 15:14:09 +1100 Subject: [PATCH] feat(ftl-schema): pre-commit validation (#4382) - Checks that the event is valid against latest known state - If still, when applying the event, it is invalid, returns a generic error to the user. --- backend/schemaservice/events.go | 157 +++++++++++++++++++++---- backend/schemaservice/schemaservice.go | 45 +++---- backend/schemaservice/state.go | 3 +- common/schema/raftevents.go | 1 + internal/raft/cluster.go | 6 +- internal/raft/cluster_test.go | 29 +++++ internal/raft/statemachine.go | 11 ++ 7 files changed, 209 insertions(+), 43 deletions(-) diff --git a/backend/schemaservice/events.go b/backend/schemaservice/events.go index b98f061934..4fdf63c9f3 100644 --- a/backend/schemaservice/events.go +++ b/backend/schemaservice/events.go @@ -46,13 +46,60 @@ func (r *SchemaState) ApplyEvent(ctx context.Context, event schema.Event) error } } -func handleDeploymentRuntimeEvent(t *SchemaState, e *schema.DeploymentRuntimeEvent) error { +// VerifyEvent verifies an event is valid for the given state, without applying it +func (r *SchemaState) VerifyEvent(ctx context.Context, event schema.Event) error { + if err := event.Validate(); err != nil { + return fmt.Errorf("invalid event: %w", err) + } + switch e := event.(type) { + case *schema.DeploymentRuntimeEvent: + return verifyDeploymentRuntimeEvent(r, e) + case *schema.ChangesetCreatedEvent: + return verifyChangesetCreatedEvent(r, e) + case *schema.ChangesetPreparedEvent: + return verifyChangesetPreparedEvent(r, e) + case *schema.ChangesetCommittedEvent: + return verifyChangesetCommittedEvent(r, e) + case *schema.ChangesetDrainedEvent: + return verifyChangesetDrainedEvent(r, e) + case *schema.ChangesetFinalizedEvent: + return verifyChangesetFinalizedEvent(r, e) + case *schema.ChangesetRollingBackEvent: + return verifyChangesetRollingBackEvent(r, e) + case *schema.ChangesetFailedEvent: + return verifyChangesetFailedEvent(r, e) + default: + return fmt.Errorf("unknown event type: %T", e) + } +} + +func verifyDeploymentRuntimeEvent(t *SchemaState, e *schema.DeploymentRuntimeEvent) error { if cs, ok := e.ChangesetKey().Get(); ok { - c, ok := t.changesets[cs] + _, ok := t.changesets[cs] if !ok { return fmt.Errorf("changeset %s not found", cs.String()) } + for _, m := range t.changesets[cs].Modules { + if m.Name == e.DeploymentKey().Payload.Module { + return nil + } + } + } + for _, m := range t.deployments { + if m.Runtime.Deployment.DeploymentKey == e.DeploymentKey() { + return nil + } + } + return fmt.Errorf("deployment %s not found", e.DeploymentKey().String()) +} + +func handleDeploymentRuntimeEvent(t *SchemaState, e *schema.DeploymentRuntimeEvent) error { + if err := verifyDeploymentRuntimeEvent(t, e); err != nil { + return err + } + if cs, ok := e.ChangesetKey().Get(); ok { module := e.DeploymentKey().Payload.Module + c := t.changesets[cs] for _, m := range c.Modules { if m.Name == module { err := e.Payload.ApplyToModule(m) @@ -77,7 +124,7 @@ func handleDeploymentRuntimeEvent(t *SchemaState, e *schema.DeploymentRuntimeEve return fmt.Errorf("deployment %s not found", e.DeploymentKey().String()) } -func handleChangesetCreatedEvent(t *SchemaState, e *schema.ChangesetCreatedEvent) error { +func verifyChangesetCreatedEvent(t *SchemaState, e *schema.ChangesetCreatedEvent) error { if existing := t.changesets[e.Changeset.Key]; existing != nil { return fmt.Errorf("changeset %s already exists ", e.Changeset.Key) } @@ -85,7 +132,6 @@ func handleChangesetCreatedEvent(t *SchemaState, e *schema.ChangesetCreatedEvent existingModules := map[string]key.Changeset{} for _, cs := range t.changesets { if cs.ModulesAreCanonical() { - //TODO: at the moment changesets accumulate forever... for _, mod := range cs.Modules { existingModules[mod.Name] = cs.Key } @@ -130,11 +176,18 @@ func handleChangesetCreatedEvent(t *SchemaState, e *schema.ChangesetCreatedEvent return fmt.Errorf("changeset failed validation %w", errors.Join(problems...)) } } + return nil +} + +func handleChangesetCreatedEvent(t *SchemaState, e *schema.ChangesetCreatedEvent) error { + if err := verifyChangesetCreatedEvent(t, e); err != nil { + return err + } t.changesets[e.Changeset.Key] = e.Changeset return nil } -func handleChangesetPreparedEvent(t *SchemaState, e *schema.ChangesetPreparedEvent) error { +func verifyChangesetPreparedEvent(t *SchemaState, e *schema.ChangesetPreparedEvent) error { changeset, ok := t.changesets[e.Key] if !ok { return fmt.Errorf("changeset %s not found", e.Key) @@ -147,6 +200,14 @@ func handleChangesetPreparedEvent(t *SchemaState, e *schema.ChangesetPreparedEve return fmt.Errorf("deployment %s has no endpoint", dep.Name) } } + return nil +} + +func handleChangesetPreparedEvent(t *SchemaState, e *schema.ChangesetPreparedEvent) error { + if err := verifyChangesetPreparedEvent(t, e); err != nil { + return err + } + changeset := t.changesets[e.Key] changeset.State = schema.ChangesetStatePrepared // TODO: what does this actually mean? Worry about it when we start implementing canaries, but it will be clunky // If everything that cares about canaries needs to scan for prepared changesets @@ -156,7 +217,7 @@ func handleChangesetPreparedEvent(t *SchemaState, e *schema.ChangesetPreparedEve return nil } -func handleChangesetCommittedEvent(ctx context.Context, t *SchemaState, e *schema.ChangesetCommittedEvent) error { +func verifyChangesetCommittedEvent(t *SchemaState, e *schema.ChangesetCommittedEvent) error { changeset, ok := t.changesets[e.Key] if !ok { return fmt.Errorf("changeset %s not found", e.Key) @@ -167,6 +228,15 @@ func handleChangesetCommittedEvent(ctx context.Context, t *SchemaState, e *schem return fmt.Errorf("deployment %s is not in correct state expected %v got %v", dep.Name, schema.DeploymentStateCanary, dep.Runtime.Deployment.State) } } + return nil +} + +func handleChangesetCommittedEvent(ctx context.Context, t *SchemaState, e *schema.ChangesetCommittedEvent) error { + if err := verifyChangesetCommittedEvent(t, e); err != nil { + return err + } + + changeset := t.changesets[e.Key] logger := log.FromContext(ctx) changeset.State = schema.ChangesetStateCommitted for _, dep := range changeset.Modules { @@ -182,8 +252,7 @@ func handleChangesetCommittedEvent(ctx context.Context, t *SchemaState, e *schem return nil } -func handleChangesetDrainedEvent(ctx context.Context, t *SchemaState, e *schema.ChangesetDrainedEvent) error { - logger := log.FromContext(ctx) +func verifyChangesetDrainedEvent(t *SchemaState, e *schema.ChangesetDrainedEvent) error { changeset, ok := t.changesets[e.Key] if !ok { return fmt.Errorf("changeset %s not found", e.Key) @@ -191,34 +260,65 @@ func handleChangesetDrainedEvent(ctx context.Context, t *SchemaState, e *schema. if changeset.State != schema.ChangesetStateCommitted { return fmt.Errorf("changeset %v is not in the correct state", changeset.Key) } + + for _, dep := range changeset.RemovingModules { + if dep.ModRuntime().ModDeployment().State != schema.DeploymentStateDraining && + dep.ModRuntime().ModDeployment().State != schema.DeploymentStateDeProvisioning { + return fmt.Errorf("deployment %s is not in correct state expected %v got %v", dep.Name, schema.DeploymentStateDeProvisioning, dep.Runtime.Deployment.State) + } + } + return nil +} + +func handleChangesetDrainedEvent(ctx context.Context, t *SchemaState, e *schema.ChangesetDrainedEvent) error { + if err := verifyChangesetDrainedEvent(t, e); err != nil { + return err + } + + logger := log.FromContext(ctx) + changeset := t.changesets[e.Key] logger.Debugf("Changeset %s drained", e.Key) for _, dep := range changeset.RemovingModules { if dep.ModRuntime().ModDeployment().State == schema.DeploymentStateDraining { dep.Runtime.Deployment.State = schema.DeploymentStateDeProvisioning - } else if dep.ModRuntime().ModDeployment().State != schema.DeploymentStateDeProvisioning { - return fmt.Errorf("deployment %s is not in correct state expected %v got %v", dep.Name, schema.DeploymentStateDeProvisioning, dep.Runtime.Deployment.State) } } changeset.State = schema.ChangesetStateDrained return nil } -func handleChangesetFinalizedEvent(ctx context.Context, r *SchemaState, e *schema.ChangesetFinalizedEvent) error { - logger := log.FromContext(ctx) - changeset, ok := r.changesets[e.Key] +func verifyChangesetFinalizedEvent(t *SchemaState, e *schema.ChangesetFinalizedEvent) error { + changeset, ok := t.changesets[e.Key] if !ok { return fmt.Errorf("changeset %s not found", e.Key) } if changeset.State != schema.ChangesetStateDrained { return fmt.Errorf("changeset %v is not in the correct state expected %v got %v", changeset.Key, schema.ChangesetStateDrained, changeset.State) } + + for _, dep := range changeset.RemovingModules { + if dep.ModRuntime().ModDeployment().State == schema.DeploymentStateDeProvisioning { + continue + } + if dep.ModRuntime().ModDeployment().State != schema.DeploymentStateDeleted { + return fmt.Errorf("deployment %s is not in correct state expected %v got %v", dep.Name, schema.DeploymentStateDeleted, dep.Runtime.Deployment.State) + } + } + return nil +} + +func handleChangesetFinalizedEvent(ctx context.Context, r *SchemaState, e *schema.ChangesetFinalizedEvent) error { + if err := verifyChangesetFinalizedEvent(r, e); err != nil { + return err + } + + logger := log.FromContext(ctx) + changeset := r.changesets[e.Key] logger.Debugf("Changeset %s de-provisioned", e.Key) for _, dep := range changeset.RemovingModules { if dep.ModRuntime().ModDeployment().State == schema.DeploymentStateDeProvisioning { dep.Runtime.Deployment.State = schema.DeploymentStateDeleted - } else if dep.ModRuntime().ModDeployment().State != schema.DeploymentStateDeleted { - return fmt.Errorf("deployment %s is not in correct state expected %v got %v", dep.Name, schema.DeploymentStateDeleted, dep.Runtime.Deployment.State) } } changeset.State = schema.ChangesetStateFinalized @@ -232,11 +332,20 @@ func handleChangesetFinalizedEvent(ctx context.Context, r *SchemaState, e *schem return nil } -func handleChangesetFailedEvent(t *SchemaState, e *schema.ChangesetFailedEvent) error { - changeset, ok := t.changesets[e.Key] +func verifyChangesetFailedEvent(t *SchemaState, e *schema.ChangesetFailedEvent) error { + _, ok := t.changesets[e.Key] if !ok { return fmt.Errorf("changeset %s not found", e.Key) } + return nil +} + +func handleChangesetFailedEvent(t *SchemaState, e *schema.ChangesetFailedEvent) error { + if err := verifyChangesetFailedEvent(t, e); err != nil { + return err + } + + changeset := t.changesets[e.Key] changeset.State = schema.ChangesetStateFailed //TODO: de-provisioning on failure? delete(t.changesets, changeset.Key) @@ -246,14 +355,22 @@ func handleChangesetFailedEvent(t *SchemaState, e *schema.ChangesetFailedEvent) t.archivedChangesets = nl return nil } -func handleChangesetRollingBackEvent(t *SchemaState, e *schema.ChangesetRollingBackEvent) error { - changeset, ok := t.changesets[e.Key] +func verifyChangesetRollingBackEvent(t *SchemaState, e *schema.ChangesetRollingBackEvent) error { + _, ok := t.changesets[e.Key] if !ok { return fmt.Errorf("changeset %s not found", e.Key) } + return nil +} + +func handleChangesetRollingBackEvent(t *SchemaState, e *schema.ChangesetRollingBackEvent) error { + if err := verifyChangesetRollingBackEvent(t, e); err != nil { + return err + } + + changeset := t.changesets[e.Key] changeset.State = schema.ChangesetStateRollingBack changeset.Error = e.Error - println("ERROR " + e.Error) for _, module := range changeset.Modules { module.Runtime.Deployment.State = schema.DeploymentStateDeProvisioning } diff --git a/backend/schemaservice/schemaservice.go b/backend/schemaservice/schemaservice.go index d8812bce92..b0b53571ef 100644 --- a/backend/schemaservice/schemaservice.go +++ b/backend/schemaservice/schemaservice.go @@ -147,7 +147,7 @@ func (s *Service) UpdateDeploymentRuntime(ctx context.Context, req *connect.Requ } changeset = &cs } - err = s.State.Publish(ctx, EventWrapper{Event: &schema.DeploymentRuntimeEvent{Changeset: changeset, Payload: event}}) + err = s.publishEvent(ctx, &schema.DeploymentRuntimeEvent{Changeset: changeset, Payload: event}) if err != nil { return nil, fmt.Errorf("could not apply event: %w", err) } @@ -200,9 +200,7 @@ func (s *Service) CreateChangeset(ctx context.Context, req *connect.Request[ftlv } // TODO: validate changeset schema with canonical schema - err = s.State.Publish(ctx, EventWrapper{Event: &schema.ChangesetCreatedEvent{ - Changeset: changeset, - }}) + err = s.publishEvent(ctx, &schema.ChangesetCreatedEvent{Changeset: changeset}) if err != nil { return nil, fmt.Errorf("could not create changeset %w", err) } @@ -216,9 +214,7 @@ func (s *Service) PrepareChangeset(ctx context.Context, req *connect.Request[ftl if err != nil { return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid changeset key: %w", err)) } - err = s.State.Publish(ctx, EventWrapper{Event: &schema.ChangesetPreparedEvent{ - Key: changesetKey, - }}) + err = s.publishEvent(ctx, &schema.ChangesetPreparedEvent{Key: changesetKey}) if err != nil { return nil, fmt.Errorf("could not prepare changeset %w", err) } @@ -231,9 +227,7 @@ func (s *Service) CommitChangeset(ctx context.Context, req *connect.Request[ftlv if err != nil { return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid changeset key: %w", err)) } - err = s.State.Publish(ctx, EventWrapper{Event: &schema.ChangesetCommittedEvent{ - Key: changesetKey, - }}) + err = s.publishEvent(ctx, &schema.ChangesetCommittedEvent{Key: changesetKey}) if err != nil { return nil, fmt.Errorf("could not commit changeset %w", err) } @@ -251,9 +245,7 @@ func (s *Service) DrainChangeset(ctx context.Context, req *connect.Request[ftlv1 if err != nil { return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid changeset key: %w", err)) } - err = s.State.Publish(ctx, EventWrapper{Event: &schema.ChangesetDrainedEvent{ - Key: changesetKey, - }}) + err = s.publishEvent(ctx, &schema.ChangesetDrainedEvent{Key: changesetKey}) if err != nil { return nil, fmt.Errorf("could not drain changeset %w", err) } @@ -265,9 +257,7 @@ func (s *Service) FinalizeChangeset(ctx context.Context, req *connect.Request[ft if err != nil { return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid changeset key: %w", err)) } - err = s.State.Publish(ctx, EventWrapper{Event: &schema.ChangesetFinalizedEvent{ - Key: changesetKey, - }}) + err = s.publishEvent(ctx, &schema.ChangesetFinalizedEvent{Key: changesetKey}) if err != nil { return nil, fmt.Errorf("could not de-provision changeset %w", err) } @@ -279,10 +269,10 @@ func (s *Service) RollbackChangeset(ctx context.Context, req *connect.Request[ft if err != nil { return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid changeset key: %w", err)) } - err = s.State.Publish(ctx, EventWrapper{Event: &schema.ChangesetRollingBackEvent{ + err = s.publishEvent(ctx, &schema.ChangesetRollingBackEvent{ Key: changesetKey, Error: req.Msg.Error, - }}) + }) if err != nil { return nil, fmt.Errorf("could not fail changeset %w", err) } @@ -295,15 +285,28 @@ func (s *Service) FailChangeset(ctx context.Context, req *connect.Request[ftlv1. if err != nil { return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid changeset key: %w", err)) } - err = s.State.Publish(ctx, EventWrapper{Event: &schema.ChangesetFailedEvent{ - Key: changesetKey, - }}) + err = s.publishEvent(ctx, &schema.ChangesetFailedEvent{Key: changesetKey}) if err != nil { return nil, fmt.Errorf("could not fail changeset %w", err) } return connect.NewResponse(&ftlv1.FailChangesetResponse{}), nil } +func (s *Service) publishEvent(ctx context.Context, event schema.Event) error { + // Verify the event against the latest known state before publishing + state, err := s.State.View(ctx) + if err != nil { + return fmt.Errorf("failed to get schema state: %w", err) + } + if err := state.VerifyEvent(ctx, event); err != nil { + return fmt.Errorf("invalid event: %w", err) + } + if err := s.State.Publish(ctx, EventWrapper{Event: event}); err != nil { + return fmt.Errorf("failed to publish event: %w", err) + } + return nil +} + func (s *Service) watchModuleChanges(ctx context.Context, subscriptionID string, sendChange func(response *ftlv1.PullSchemaResponse) error) error { logger := log.FromContext(ctx).Scope(subscriptionID) diff --git a/backend/schemaservice/state.go b/backend/schemaservice/state.go index 758e18f584..2d5ce97deb 100644 --- a/backend/schemaservice/state.go +++ b/backend/schemaservice/state.go @@ -18,6 +18,7 @@ import ( "github.com/block/ftl/internal/channels" "github.com/block/ftl/internal/key" "github.com/block/ftl/internal/log" + "github.com/block/ftl/internal/raft" "github.com/block/ftl/internal/statemachine" ) @@ -274,7 +275,7 @@ func (c *schemaStateMachine) Publish(msg EventWrapper) error { // TODO: we need to validate the events before they are // committed to the log logger.Errorf(err, "failed to apply event") - return nil + return raft.ErrInvalidEvent } // Notify all subscribers using broadcaster c.notifier.Notify(c.runningCtx) diff --git a/common/schema/raftevents.go b/common/schema/raftevents.go index 790628bf52..6d38540890 100644 --- a/common/schema/raftevents.go +++ b/common/schema/raftevents.go @@ -14,6 +14,7 @@ import ( //protobuf:export type Event interface { event() + // Validate the event is internally consistent Validate() error // DebugString returns a string representation of the event for debugging purposes DebugString() string diff --git a/internal/raft/cluster.go b/internal/raft/cluster.go index f48326295a..c7064c9c17 100644 --- a/internal/raft/cluster.go +++ b/internal/raft/cluster.go @@ -176,11 +176,15 @@ func (s *ShardHandle[Q, R, E]) Publish(ctx context.Context, msg E) error { if err := s.cluster.withRetry(ctx, s.shardID, s.cluster.runtimeReplicaID, func(ctx context.Context) error { logger.Debugf("Proposing event to shard %d on replica %d", s.shardID, s.cluster.runtimeReplicaID) - _, err := s.cluster.nh.SyncPropose(ctx, s.session, msgBytes) + res, err := s.cluster.nh.SyncPropose(ctx, s.session, msgBytes) if err != nil { return err //nolint:wrapcheck } s.session.ProposalCompleted() + + if res.Value == InvalidEventValue { + return ErrInvalidEvent + } return nil }, dragonboat.ErrShardNotReady, dragonboat.ErrTimeout); err != nil { return fmt.Errorf("failed to propose event: %w", err) diff --git a/internal/raft/cluster_test.go b/internal/raft/cluster_test.go index bf47d7daee..3bcc2ed90f 100644 --- a/internal/raft/cluster_test.go +++ b/internal/raft/cluster_test.go @@ -40,6 +40,11 @@ type IntStateMachine struct { var _ sm.Snapshotting[int64, int64, IntEvent] = &IntStateMachine{} func (s *IntStateMachine) Publish(event IntEvent) error { + if event == 12345 { + // 12345 is not a valid number, so we return an error + return raft.ErrInvalidEvent + } + s.sum += int64(event) return nil } @@ -182,6 +187,30 @@ func TestStateIter(t *testing.T) { assert.True(t, iterops.Contains(changes, 2)) } +func TestInvalidEvents(t *testing.T) { + t.Parallel() + if testing.Short() { + t.SkipNow() + } + ctx := testContext(t) + + _, shards := startClusters(ctx, t, 2, func(b *raft.Builder) []sm.Handle[int64, int64, IntEvent] { + return []sm.Handle[int64, int64, IntEvent]{ + raft.AddShard(ctx, b, 1, &IntStateMachine{}), + } + }) + + // 12345 is not actually a valid number, so we expect an error + err := shards[0][0].Publish(ctx, IntEvent(12345)) + assert.IsError(t, err, raft.ErrInvalidEvent) + + // check that the event did not impact the state machine + assertShardValue(ctx, t, 0, shards[0][0]) + // and we can still publish valid events + assert.NoError(t, shards[0][0].Publish(ctx, IntEvent(1))) + assertShardValue(ctx, t, 1, shards[0][0]) +} + func testBuilder(t *testing.T, addresses []*net.TCPAddr, address string, _ *url.URL) *raft.Builder { members := make([]string, len(addresses)) for i, member := range addresses { diff --git a/internal/raft/statemachine.go b/internal/raft/statemachine.go index e67497e72d..49862a7ef5 100644 --- a/internal/raft/statemachine.go +++ b/internal/raft/statemachine.go @@ -1,6 +1,7 @@ package raft import ( + "errors" "fmt" "io" @@ -9,6 +10,11 @@ import ( sm "github.com/block/ftl/internal/statemachine" ) +// ErrInvalidEvent is returned if we are attempting to publish an invalid event. +var ErrInvalidEvent = errors.New("invalid event") + +const InvalidEventValue = 0x1001 + // stateMachineShim is a shim to convert a typed StateMachine to a dragonboat statemachine.IStateMachine. type stateMachineShim[Q any, R any, E sm.Marshallable, EPtr sm.Unmarshallable[E]] struct { sm sm.Snapshotting[Q, R, E] @@ -44,6 +50,11 @@ func (s *stateMachineShim[Q, R, E, EPtr]) Update(entry statemachine.Entry) (stat return statemachine.Result{}, fmt.Errorf("failed to unmarshal event: %w", err) } if err := s.sm.Publish(to); err != nil { + if errors.Is(err, ErrInvalidEvent) { + return statemachine.Result{ + Value: InvalidEventValue, + }, nil + } return statemachine.Result{}, fmt.Errorf("failed to update state machine: %w", err) }