diff --git a/api/orchestration.go b/api/orchestration.go index 2e73662..1ef2c7c 100644 --- a/api/orchestration.go +++ b/api/orchestration.go @@ -23,6 +23,29 @@ var ( EmptyInstanceID = InstanceID("") ) +type CreateOrchestrationAction = protos.CreateOrchestrationAction + +const ( + REUSE_ID_ACTION_ERROR CreateOrchestrationAction = protos.CreateOrchestrationAction_ERROR + REUSE_ID_ACTION_IGNORE CreateOrchestrationAction = protos.CreateOrchestrationAction_IGNORE + REUSE_ID_ACTION_TERMINATE CreateOrchestrationAction = protos.CreateOrchestrationAction_TERMINATE +) + +type OrchestrationStatus = protos.OrchestrationStatus + +const ( + RUNTIME_STATUS_RUNNING OrchestrationStatus = protos.OrchestrationStatus_ORCHESTRATION_STATUS_RUNNING + RUNTIME_STATUS_COMPLETED OrchestrationStatus = protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED + RUNTIME_STATUS_CONTINUED_AS_NEW OrchestrationStatus = protos.OrchestrationStatus_ORCHESTRATION_STATUS_CONTINUED_AS_NEW + RUNTIME_STATUS_FAILED OrchestrationStatus = protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED + RUNTIME_STATUS_CANCELED OrchestrationStatus = protos.OrchestrationStatus_ORCHESTRATION_STATUS_CANCELED + RUNTIME_STATUS_TERMINATED OrchestrationStatus = protos.OrchestrationStatus_ORCHESTRATION_STATUS_TERMINATED + RUNTIME_STATUS_PENDING OrchestrationStatus = protos.OrchestrationStatus_ORCHESTRATION_STATUS_PENDING + RUNTIME_STATUS_SUSPENDED OrchestrationStatus = protos.OrchestrationStatus_ORCHESTRATION_STATUS_SUSPENDED +) + +type OrchestrationIdReusePolicy = protos.OrchestrationIdReusePolicy + // InstanceID is a unique identifier for an orchestration instance. type InstanceID string diff --git a/backend/client.go b/backend/client.go index 4aae340..b38cfd8 100644 --- a/backend/client.go +++ b/backend/client.go @@ -59,7 +59,7 @@ func (c *backendClient) ScheduleNewOrchestration(ctx context.Context, orchestrat tc := helpers.TraceContextFromSpan(span) e := helpers.NewExecutionStartedEvent(req.Name, req.InstanceId, req.Input, nil, tc) - if err := c.be.CreateOrchestrationInstance(ctx, e); err != nil { + if err := c.be.CreateOrchestrationInstance(ctx, e, WithOrchestrationIdReusePolicy(req.OrchestrationIdReusePolicy)); err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return api.EmptyInstanceID, fmt.Errorf("failed to start orchestration: %w", err) diff --git a/backend/sqlite/sqlite.go b/backend/sqlite/sqlite.go index 5699c02..3ce4dbc 100644 --- a/backend/sqlite/sqlite.go +++ b/backend/sqlite/sqlite.go @@ -434,15 +434,7 @@ func (be *sqliteBackend) CreateOrchestrationInstance(ctx context.Context, e *bac return nil } -func buildStatusSet(statuses []protos.OrchestrationStatus) map[protos.OrchestrationStatus]struct{} { - statusSet := make(map[protos.OrchestrationStatus]struct{}) - for _, status := range statuses { - statusSet[status] = struct{}{} - } - return statusSet -} - -func (be *sqliteBackend) createOrchestrationInstanceInternal(ctx context.Context, e *backend.HistoryEvent, tx *sql.Tx, opts ...backend.OrchestrationIdReusePolicyOptions) (string, error) { +func (be *sqliteBackend) createOrchestrationInstanceInternal(ctx context.Context, e *backend.HistoryEvent, tx *sql.Tx, opts ...backend.OrchestrationIdReusePolicyOptions) (string, error) { if e == nil { return "", errors.New("HistoryEvent must be non-nil") } else if e.Timestamp == nil { @@ -466,6 +458,7 @@ func (be *sqliteBackend) createOrchestrationInstanceInternal(ctx context.Context return "", err } + // instance with same ID already exists if rows <= 0 { return instanceID, be.handleInstanceExists(ctx, tx, startEvent, policy, e) } @@ -500,7 +493,7 @@ func insertOrIgnoreInstanceTableInternal(ctx context.Context, tx *sql.Tx, e *bac if err != nil { return -1, fmt.Errorf("failed to count the rows affected: %w", err) } - return rows, nil; + return rows, nil } func (be *sqliteBackend) handleInstanceExists(ctx context.Context, tx *sql.Tx, startEvent *protos.ExecutionStartedEvent, policy *protos.OrchestrationIdReusePolicy, e *backend.HistoryEvent) error { @@ -518,10 +511,8 @@ func (be *sqliteBackend) handleInstanceExists(ctx context.Context, tx *sql.Tx, s return fmt.Errorf("failed to scan the Instances table result: %w", err) } - // instance already exists - targetStatusValues := buildStatusSet(policy.OperationStatus) // status not match, return instance duplicate error - if _, ok := targetStatusValues[helpers.FromRuntimeStatusString(*runtimeStatus)]; !ok { + if !isStatusMatch(policy.OperationStatus, helpers.FromRuntimeStatusString(*runtimeStatus)) { return api.ErrDuplicateInstance } @@ -533,8 +524,8 @@ func (be *sqliteBackend) handleInstanceExists(ctx context.Context, tx *sql.Tx, s return api.ErrIgnoreInstance case protos.CreateOrchestrationAction_TERMINATE: // terminate existing instance - if err := be.cleanupOrchestrationStateInternal(ctx, tx, api.InstanceID(startEvent.OrchestrationInstance.InstanceId),false); err != nil { - return err + if err := be.cleanupOrchestrationStateInternal(ctx, tx, api.InstanceID(startEvent.OrchestrationInstance.InstanceId), false); err != nil { + return fmt.Errorf("failed to cleanup orchestration status: %w", err) } // create a new instance var rows int64 @@ -552,7 +543,16 @@ func (be *sqliteBackend) handleInstanceExists(ctx context.Context, tx *sql.Tx, s return api.ErrDuplicateInstance } -func (be *sqliteBackend) cleanupOrchestrationStateInternal(ctx context.Context, tx *sql.Tx, id api.InstanceID, onlyIfCompleted bool) error { +func isStatusMatch(statuses []protos.OrchestrationStatus, runtimeStatus protos.OrchestrationStatus) bool { + for _, status := range statuses { + if status == runtimeStatus { + return true + } + } + return false +} + +func (be *sqliteBackend) cleanupOrchestrationStateInternal(ctx context.Context, tx *sql.Tx, id api.InstanceID, requireCompleted bool) error { row := tx.QueryRowContext(ctx, "SELECT 1 FROM Instances WHERE [InstanceID] = ?", string(id)) if err := row.Err(); err != nil { return fmt.Errorf("failed to query for instance existence: %w", err) @@ -565,13 +565,13 @@ func (be *sqliteBackend) cleanupOrchestrationStateInternal(ctx context.Context, return fmt.Errorf("failed to scan instance existence: %w", err) } - if onlyIfCompleted { + if requireCompleted { // purge orchestration in ['COMPLETED', 'FAILED', 'TERMINATED'] dbResult, err := tx.ExecContext(ctx, "DELETE FROM Instances WHERE [InstanceID] = ? AND [RuntimeStatus] IN ('COMPLETED', 'FAILED', 'TERMINATED')", string(id)) if err != nil { return fmt.Errorf("failed to delete from the Instances table: %w", err) } - + rowsAffected, err := dbResult.RowsAffected() if err != nil { return fmt.Errorf("failed to get rows affected in Instances delete operation: %w", err) diff --git a/tests/grpc/grpc_test.go b/tests/grpc/grpc_test.go index daa6dca..8f0de3c 100644 --- a/tests/grpc/grpc_test.go +++ b/tests/grpc/grpc_test.go @@ -228,7 +228,7 @@ func Test_Grpc_Terminate_Recursive(t *testing.T) { } } -func Test_Grpc_ReuseInstanceIDSkip(t *testing.T) { +func Test_Grpc_ReuseInstanceIDIgnore(t *testing.T) { delayTime := 2 * time.Second r := task.NewTaskRegistry() r.AddOrchestratorN("SingleActivity", func(ctx *task.OrchestrationContext) (any, error) { @@ -252,13 +252,9 @@ func Test_Grpc_ReuseInstanceIDSkip(t *testing.T) { cancelListener := startGrpcListener(t, r) defer cancelListener() instanceID := api.InstanceID("SKIP_IF_RUNNING_OR_COMPLETED") - reuseIdPolicy := &protos.OrchestrationIdReusePolicy{ - Action: protos.CreateOrchestrationAction_IGNORE, - OperationStatus: []protos.OrchestrationStatus{ - protos.OrchestrationStatus_ORCHESTRATION_STATUS_RUNNING, - protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED, - protos.OrchestrationStatus_ORCHESTRATION_STATUS_PENDING, - }, + reuseIdPolicy := &api.OrchestrationIdReusePolicy{ + Action: api.REUSE_ID_ACTION_IGNORE, + OperationStatus: []api.OrchestrationStatus{api.RUNTIME_STATUS_RUNNING, api.RUNTIME_STATUS_COMPLETED, api.RUNTIME_STATUS_PENDING}, } id, err := grpcClient.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("世界"), api.WithInstanceID(instanceID)) @@ -304,13 +300,9 @@ func Test_Grpc_ReuseInstanceIDTerminate(t *testing.T) { cancelListener := startGrpcListener(t, r) defer cancelListener() instanceID := api.InstanceID("TERMINATE_IF_RUNNING_OR_COMPLETED") - reuseIdPolicy := &protos.OrchestrationIdReusePolicy{ - Action: protos.CreateOrchestrationAction_TERMINATE, - OperationStatus: []protos.OrchestrationStatus{ - protos.OrchestrationStatus_ORCHESTRATION_STATUS_RUNNING, - protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED, - protos.OrchestrationStatus_ORCHESTRATION_STATUS_PENDING, - }, + reuseIdPolicy := &api.OrchestrationIdReusePolicy{ + Action: api.REUSE_ID_ACTION_TERMINATE, + OperationStatus: []api.OrchestrationStatus{api.RUNTIME_STATUS_RUNNING, api.RUNTIME_STATUS_COMPLETED, api.RUNTIME_STATUS_PENDING}, } id, err := grpcClient.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("世界"), api.WithInstanceID(instanceID)) @@ -332,7 +324,7 @@ func Test_Grpc_ReuseInstanceIDTerminate(t *testing.T) { assert.True(t, pivotTime.Before(metadata.CreatedAt)) } -func Test_Grpc_ReuseInstanceIDThrow(t *testing.T) { +func Test_Grpc_ReuseInstanceIDError(t *testing.T) { delayTime := 4 * time.Second r := task.NewTaskRegistry() r.AddOrchestratorN("SingleActivity", func(ctx *task.OrchestrationContext) (any, error) { diff --git a/tests/orchestrations_test.go b/tests/orchestrations_test.go index 87da470..173d33e 100644 --- a/tests/orchestrations_test.go +++ b/tests/orchestrations_test.go @@ -826,6 +826,144 @@ func Test_RecreateCompletedOrchestration(t *testing.T) { ) } +func Test_SingleActivity_ReuseInstanceIDIgnore(t *testing.T) { + // Registration + r := task.NewTaskRegistry() + r.AddOrchestratorN("SingleActivity", func(ctx *task.OrchestrationContext) (any, error) { + var input string + if err := ctx.GetInput(&input); err != nil { + return nil, err + } + var output string + err := ctx.CallActivity("SayHello", task.WithActivityInput(input)).Await(&output) + return output, err + }) + r.AddActivityN("SayHello", func(ctx task.ActivityContext) (any, error) { + var name string + if err := ctx.GetInput(&name); err != nil { + return nil, err + } + return fmt.Sprintf("Hello, %s!", name), nil + }) + + // Initialization + ctx := context.Background() + client, worker := initTaskHubWorker(ctx, r) + defer worker.Shutdown(ctx) + + instanceID := api.InstanceID("IGNORE_IF_RUNNING_OR_COMPLETED") + reuseIdPolicy := &api.OrchestrationIdReusePolicy{ + Action: api.REUSE_ID_ACTION_IGNORE, + OperationStatus: []api.OrchestrationStatus{api.RUNTIME_STATUS_RUNNING, api.RUNTIME_STATUS_COMPLETED, api.RUNTIME_STATUS_PENDING}, + } + + // Run the orchestration + id, err := client.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("世界"), api.WithInstanceID(instanceID)) + require.NoError(t, err) + // wait orchestration to start + client.WaitForOrchestrationStart(ctx, id) + pivotTime := time.Now() + // schedule again, it should ignore creating the new orchestration + id, err = client.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("World"), api.WithInstanceID(id), api.WithOrchestrationIdReusePolicy(reuseIdPolicy)) + require.NoError(t, err) + timeoutCtx, cancelTimeout := context.WithTimeout(ctx, 30*time.Second) + defer cancelTimeout() + metadata, err := client.WaitForOrchestrationCompletion(timeoutCtx, id) + require.NoError(t, err) + assert.Equal(t, true, metadata.IsComplete()) + // the first orchestration should complete as the second one is ignored + assert.Equal(t, `"Hello, 世界!"`, metadata.SerializedOutput) + // assert the orchestration created timestamp + assert.True(t, pivotTime.After(metadata.CreatedAt)) +} + +func Test_SingleActivity_ReuseInstanceIDTerminate(t *testing.T) { + // Registration + r := task.NewTaskRegistry() + r.AddOrchestratorN("SingleActivity", func(ctx *task.OrchestrationContext) (any, error) { + var input string + if err := ctx.GetInput(&input); err != nil { + return nil, err + } + var output string + err := ctx.CallActivity("SayHello", task.WithActivityInput(input)).Await(&output) + return output, err + }) + r.AddActivityN("SayHello", func(ctx task.ActivityContext) (any, error) { + var name string + if err := ctx.GetInput(&name); err != nil { + return nil, err + } + return fmt.Sprintf("Hello, %s!", name), nil + }) + + // Initialization + ctx := context.Background() + client, worker := initTaskHubWorker(ctx, r) + defer worker.Shutdown(ctx) + + instanceID := api.InstanceID("TERMINATE_IF_RUNNING_OR_COMPLETED") + reuseIdPolicy := &api.OrchestrationIdReusePolicy{ + Action: api.REUSE_ID_ACTION_TERMINATE, + OperationStatus: []api.OrchestrationStatus{api.RUNTIME_STATUS_RUNNING, api.RUNTIME_STATUS_COMPLETED, api.RUNTIME_STATUS_PENDING}, + } + + // Run the orchestration + id, err := client.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("世界"), api.WithInstanceID(instanceID)) + require.NoError(t, err) + // wait orchestration to start + client.WaitForOrchestrationStart(ctx, id) + pivotTime := time.Now() + // schedule again, it should terminate the first orchestration and start a new one + id, err = client.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("World"), api.WithInstanceID(id), api.WithOrchestrationIdReusePolicy(reuseIdPolicy)) + require.NoError(t, err) + timeoutCtx, cancelTimeout := context.WithTimeout(ctx, 30*time.Second) + defer cancelTimeout() + metadata, err := client.WaitForOrchestrationCompletion(timeoutCtx, id) + require.NoError(t, err) + assert.Equal(t, true, metadata.IsComplete()) + // the second orchestration should complete. + assert.Equal(t, `"Hello, World!"`, metadata.SerializedOutput) + // assert the orchestration created timestamp + assert.True(t, pivotTime.Before(metadata.CreatedAt)) +} + +func Test_SingleActivity_ReuseInstanceIDError(t *testing.T) { + // Registration + r := task.NewTaskRegistry() + r.AddOrchestratorN("SingleActivity", func(ctx *task.OrchestrationContext) (any, error) { + var input string + if err := ctx.GetInput(&input); err != nil { + return nil, err + } + var output string + err := ctx.CallActivity("SayHello", task.WithActivityInput(input)).Await(&output) + return output, err + }) + r.AddActivityN("SayHello", func(ctx task.ActivityContext) (any, error) { + var name string + if err := ctx.GetInput(&name); err != nil { + return nil, err + } + return fmt.Sprintf("Hello, %s!", name), nil + }) + + // Initialization + ctx := context.Background() + client, worker := initTaskHubWorker(ctx, r) + defer worker.Shutdown(ctx) + + instanceID := api.InstanceID("ERROR_IF_RUNNING_OR_COMPLETED") + + // Run the orchestration + id, err := client.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("世界"), api.WithInstanceID(instanceID)) + require.NoError(t, err) + id, err = client.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("World"), api.WithInstanceID(id)) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "orchestration instance already exists") + } +} + func initTaskHubWorker(ctx context.Context, r *task.TaskRegistry, opts ...backend.NewTaskWorkerOptions) (backend.TaskHubClient, backend.TaskHubWorker) { // TODO: Switch to options pattern logger := backend.DefaultLogger()