diff --git a/tests/nexus_workflow_test.go b/tests/nexus_workflow_test.go index aca43e668a2..6e425025eb8 100644 --- a/tests/nexus_workflow_test.go +++ b/tests/nexus_workflow_test.go @@ -31,7 +31,6 @@ import ( "net/http" "slices" "strings" - "sync/atomic" "testing" "time" @@ -2216,37 +2215,30 @@ func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationWithMultipleCallers() { return "hello " + input, nil } - opConflictFail := temporalnexus.NewWorkflowRunOperation( - "opConflictFail", - handlerWf, - func(ctx context.Context, input string, opts nexus.StartOperationOptions) (client.StartWorkflowOptions, error) { - return client.StartWorkflowOptions{ - ID: handlerWorkflowID, - }, nil - }, - ) - svc.Register(opConflictFail) - - opConflictUseExisting := temporalnexus.NewWorkflowRunOperation( - "opConflictUseExisting", + op := temporalnexus.NewWorkflowRunOperation( + "op", handlerWf, func(ctx context.Context, input string, opts nexus.StartOperationOptions) (client.StartWorkflowOptions, error) { + var conflictPolicy enumspb.WorkflowIdConflictPolicy + if input == "conflict-policy-use-existing" { + conflictPolicy = enumspb.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING + } return client.StartWorkflowOptions{ ID: handlerWorkflowID, - WorkflowIDConflictPolicy: enumspb.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING, + WorkflowIDConflictPolicy: conflictPolicy, }, nil }, ) - svc.Register(opConflictUseExisting) + svc.MustRegister(op) type CallerWfOutput struct { - CntOk int32 - CntErr int32 + CntOk int + CntErr int } - callerWf := func(ctx workflow.Context, op string, numCalls int) (CallerWfOutput, error) { - var cntOk atomic.Int32 - var cntErr atomic.Int32 + callerWf := func(ctx workflow.Context, input string, numCalls int) (CallerWfOutput, error) { + output := CallerWfOutput{} + var retError error wg := workflow.NewWaitGroup(ctx) execOpCh := workflow.NewChannel(ctx) @@ -2256,22 +2248,31 @@ func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationWithMultipleCallers() { wg.Add(1) workflow.Go(ctx, func(ctx workflow.Context) { defer wg.Done() - fut := client.ExecuteOperation(ctx, op, "caller", workflow.NexusOperationOptions{}) + fut := client.ExecuteOperation(ctx, op, input, workflow.NexusOperationOptions{}) var exec workflow.NexusOperationExecution err := fut.GetNexusOperationExecution().Get(ctx, &exec) execOpCh.Send(ctx, nil) if err != nil { - cntErr.Add(1) + output.CntErr++ var handlerErr *nexus.HandlerError - s.ErrorAs(err, &handlerErr) - s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) + var appErr *temporal.ApplicationError + if !errors.As(err, &handlerErr) { + retError = err + } else if !errors.As(handlerErr, &appErr) { + retError = err + } else if appErr.Type() != "WorkflowExecutionAlreadyStarted" { + retError = err + } return } - cntOk.Add(1) + output.CntOk++ var res string err = fut.Get(ctx, &res) - s.NoError(err) - s.Equal("hello caller", res) + if err != nil { + retError = err + } else if res != "hello "+input { + retError = fmt.Errorf("unexpected result from handler workflow: %q", res) + } }) } @@ -2282,7 +2283,7 @@ func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationWithMultipleCallers() { // signal handler workflow so it will complete workflow.SignalExternalWorkflow(ctx, handlerWorkflowID, "", "terminate", nil).Get(ctx, nil) wg.Wait(ctx) - return CallerWfOutput{CntOk: cntOk.Load(), CntErr: cntErr.Load()}, nil + return output, retError } w.RegisterNexusService(svc) @@ -2292,42 +2293,41 @@ func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationWithMultipleCallers() { defer w.Stop() testCases := []struct { - op string - numCalls int + input string checkOutput func(t *testing.T, numCalls int, res CallerWfOutput) }{ { - op: "opConflictFail", - numCalls: 5, + input: "conflict-policy-fail", checkOutput: func(t *testing.T, numCalls int, res CallerWfOutput) { s.EqualValues(1, res.CntOk) s.EqualValues(numCalls-1, res.CntErr) }, }, { - op: "opConflictUseExisting", - numCalls: 5, + input: "conflict-policy-use-existing", checkOutput: func(t *testing.T, numCalls int, res CallerWfOutput) { s.EqualValues(numCalls, res.CntOk) }, }, } + // number of concurrent Nexus operation calls + numCalls := 5 for _, tc := range testCases { - s.Run(tc.op, func() { + s.Run(tc.input, func() { run, err := s.SdkClient().ExecuteWorkflow( ctx, client.StartWorkflowOptions{ TaskQueue: callerTaskQueue, }, callerWf, - tc.op, - tc.numCalls, + tc.input, + numCalls, ) s.NoError(err) var res CallerWfOutput s.NoError(run.Get(ctx, &res)) - tc.checkOutput(s.T(), tc.numCalls, res) + tc.checkOutput(s.T(), numCalls, res) }) } }