diff --git a/common/metrics/metric_defs.go b/common/metrics/metric_defs.go index 1de307e9cff..3cff13bc40d 100644 --- a/common/metrics/metric_defs.go +++ b/common/metrics/metric_defs.go @@ -1487,6 +1487,27 @@ var ( VerifyReplicationTasksLatency = NewTimerDef("verify_replication_tasks_latency") VerifyDescribeMutableStateLatency = NewTimerDef("verify_describe_mutable_state_latency") + // Sharded force replication. The sharded ReplicateBatch activity runs + // many executions per invocation, so per-exec timing is the meaningful + // granularity — the batch-level *_tasks_latency timers above scale with + // BatchSize and aren't comparable across configurations. + GenerateReplicationTaskLatency = NewTimerDef("generate_replication_task_latency") + VerifyReplicationTaskLatency = NewTimerDef("verify_replication_task_latency") + // VerifyReplicationTaskBusy counts verify attempts where the passive + // cluster returned RESOURCE_EXHAUSTED_CAUSE_BUSY_WORKFLOW — the cache + // lock is held while history is being applied. A sign of progress that + // doesn't reset the per-shard no-progress timer. + VerifyReplicationTaskBusy = NewCounterDef("verify_replication_task_busy") + // VerifyReplicationTaskPending counts verify attempts where + // DescribeMutableState succeeded but the workflowVerifier saw the target + // lagging the source. A high pending vs. success ratio means verify is + // polling faster than apply can catch up. + VerifyReplicationTaskPending = NewCounterDef("verify_replication_task_pending") + // ReplicatedWorkflowCount accumulates verified-exec counts across each + // ReplicateBatch activity return. Emitted from the workflow so the + // counter is monotonic across activity retries. 
+ ReplicatedWorkflowCount = NewCounterDef("replicated_workflow_count") + // Replication NamespaceReplicationTaskAckLevelGauge = NewGaugeDef("namespace_replication_task_ack_level") NamespaceReplicationDLQAckLevelGauge = NewGaugeDef("namespace_dlq_ack_level") diff --git a/common/primitives/task_queues.go b/common/primitives/task_queues.go index 0c6c049f11a..0c6dd2dab59 100644 --- a/common/primitives/task_queues.go +++ b/common/primitives/task_queues.go @@ -17,6 +17,7 @@ const ( internalTaskQueuePerNSPrefix = "temporal-sys-per-ns-" MigrationActivityTQ = "temporal-sys-migration-activity-tq" + MigrationShardedActivityTQ = "temporal-sys-migration-sharded-activity-tq" AddSearchAttributesActivityTQ = "temporal-sys-add-search-attributes-activity-tq" DeleteNamespaceActivityTQ = "temporal-sys-delete-namespace-activity-tq" DLQActivityTQ = "temporal-sys-dlq-activity-tq" diff --git a/service/worker/migration/activities.go b/service/worker/migration/activities.go index 22173be166e..e609bf65d32 100644 --- a/service/worker/migration/activities.go +++ b/service/worker/migration/activities.go @@ -31,6 +31,7 @@ import ( "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/quotas" "go.temporal.io/server/common/rpc/interceptor" + "go.temporal.io/server/common/sdk" workercommon "go.temporal.io/server/service/worker/common" "google.golang.org/grpc/metadata" ) @@ -71,12 +72,11 @@ type ( } verifyReplicationTasksRequest struct { - Namespace string - NamespaceID string - TargetClusterEndpoint string - TargetClusterName string - VerifyInterval time.Duration `validate:"gte=0"` - Executions []*ExecutionInfo + Namespace string + NamespaceID string + TargetClusterName string + VerifyInterval time.Duration `validate:"gte=0"` + Executions []*ExecutionInfo } verifyReplicationTasksResponse struct { @@ -92,6 +92,14 @@ type ( NamespaceID string } + DescribeTargetClusterRequest struct { + TargetClusterName string + } + + DescribeTargetClusterResponse struct { + ShardCount int32 + } + 
ReplicationStatus struct { MaxReplicationTaskIds map[int32]int64 } @@ -129,6 +137,12 @@ type ( enableHistoryRateLimiter dynamicconfig.BoolPropertyFn workflowVerifier WorkflowVerifier chasmRegistry *chasm.Registry + // sdkClientFactory resolves the system SDK client lazily for the + // sharded ReplicateBatch activity's mid-flight ReleaseShards signal. + // Eager resolution at fx-wire time tries to dial the frontend before + // it's listening; the factory's internal sync.Once guarantees a + // single dial on first use. + sdkClientFactory sdk.ClientFactory } shardStatus struct { @@ -187,6 +201,23 @@ func (a *activities) GetMetadata(_ context.Context, request MetadataRequest) (*M }, nil } +// DescribeTargetCluster fetches the remote cluster's history shard count via +// its admin DescribeCluster RPC. The remote must be registered with the +// source cluster's cluster metadata (the cluster name doubles as the +// adminClient cache key) — which is already a prerequisite for force +// replication, since the source generates replication tasks against it. 
+func (a *activities) DescribeTargetCluster(ctx context.Context, req DescribeTargetClusterRequest) (*DescribeTargetClusterResponse, error) { + remoteAdminClient, err := a.clientBean.GetRemoteAdminClient(req.TargetClusterName) + if err != nil { + return nil, err + } + resp, err := remoteAdminClient.DescribeCluster(ctx, &adminservice.DescribeClusterRequest{}) + if err != nil { + return nil, err + } + return &DescribeTargetClusterResponse{ShardCount: resp.GetHistoryShardCount()}, nil +} + // GetMaxReplicationTaskIDs returns max replication task id per shard func (a *activities) GetMaxReplicationTaskIDs(ctx context.Context) (*ReplicationStatus, error) { ctx = headers.SetCallerInfo(ctx, headers.SystemPreemptableCallerInfo) diff --git a/service/worker/migration/activities_test.go b/service/worker/migration/activities_test.go index 4ca1d7492c9..6bc0e625024 100644 --- a/service/worker/migration/activities_test.go +++ b/service/worker/migration/activities_test.go @@ -111,6 +111,7 @@ func (s *activitiesSuite) SetupTest() { s.mockNamespaceReplicationQueue = persistence.NewMockNamespaceReplicationQueue(s.controller) s.mockNamespaceRegistry = namespace.NewMockRegistry(s.controller) s.mockClientBean = client.NewMockBean(s.controller) + s.mockClientFactory = client.NewMockFactory(s.controller) s.mockFrontendClient = workflowservicemock.NewMockWorkflowServiceClient(s.controller) s.mockAdminClient = adminservicemock.NewMockAdminServiceClient(s.controller) diff --git a/service/worker/migration/force_replication_workflow.go b/service/worker/migration/force_replication_workflow.go index 70a46222692..e7a6899369d 100644 --- a/service/worker/migration/force_replication_workflow.go +++ b/service/worker/migration/force_replication_workflow.go @@ -38,7 +38,6 @@ type ( // Used for verifying workflow executions were replicated successfully on target cluster. 
EnableVerification bool - TargetClusterEndpoint string TargetClusterName string VerifyIntervalInSeconds int `validate:"gte=0"` @@ -87,6 +86,17 @@ type ( ReplicatedWorkflowCount int64 ReplicatedWorkflowCountPerSecond float64 PageTokenForRestart []byte + + // Sharded-workflow-only recovery bundle: feed these three + // fields back into a fresh ShardedForceReplicationWorkflow's + // NextPageToken / ResumeShards / RecoveredBuckets params to + // resume from a failed run without missing executions. Left + // zero by the legacy ForceReplicationWorkflow variants — + // their PageTokenForRestart is the start-of-run token and + // already covers all in-flight execs at restart cost. + RecoveryNextPageToken []byte + RecoveryResumeShards []ResumeShard + RecoveryBuckets BatchPayload } ) @@ -342,8 +352,8 @@ func validateAndSetForceReplicationParams(ctx workflow.Context, params *ForceRep return temporal.NewNonRetryableApplicationError("InvalidArgument: Namespace is required", "InvalidArgument", nil) } - if params.EnableVerification && len(params.TargetClusterEndpoint) == 0 && len(params.TargetClusterName) == 0 { - return temporal.NewNonRetryableApplicationError("InvalidArgument: TargetClusterEndpoint or TargetClusterName is required with verification enabled", "InvalidArgument", nil) + if params.EnableVerification && len(params.TargetClusterName) == 0 { + return temporal.NewNonRetryableApplicationError("InvalidArgument: TargetClusterName is required with verification enabled", "InvalidArgument", nil) } if params.ConcurrentActivityCount <= 0 { @@ -512,12 +522,11 @@ func enqueueReplicationTasks(ctx workflow.Context, executionsCh workflow.Channel actx, a.VerifyReplicationTasks, &verifyReplicationTasksRequest{ - TargetClusterEndpoint: params.TargetClusterEndpoint, - TargetClusterName: params.TargetClusterName, - Namespace: params.Namespace, - NamespaceID: namespaceID, - Executions: migrationExecutions, - VerifyInterval: time.Duration(params.VerifyIntervalInSeconds) * time.Second, + 
TargetClusterName: params.TargetClusterName, + Namespace: params.Namespace, + NamespaceID: namespaceID, + Executions: migrationExecutions, + VerifyInterval: time.Duration(params.VerifyIntervalInSeconds) * time.Second, }) pendingVerifyTasks++ @@ -615,12 +624,11 @@ func enqueueReplicationTasksLocal( lactx, a.VerifyReplicationTasks, &verifyReplicationTasksRequest{ - TargetClusterEndpoint: params.TargetClusterEndpoint, - TargetClusterName: params.TargetClusterName, - Namespace: params.Namespace, - NamespaceID: namespaceID, - Executions: executions, - VerifyInterval: time.Duration(params.VerifyIntervalInSeconds) * time.Second, + TargetClusterName: params.TargetClusterName, + Namespace: params.Namespace, + NamespaceID: namespaceID, + Executions: executions, + VerifyInterval: time.Duration(params.VerifyIntervalInSeconds) * time.Second, }) pendingVerifyTasks++ diff --git a/service/worker/migration/force_replication_workflow_test.go b/service/worker/migration/force_replication_workflow_test.go index 8cb9e9e9188..479724c2da8 100644 --- a/service/worker/migration/force_replication_workflow_test.go +++ b/service/worker/migration/force_replication_workflow_test.go @@ -108,7 +108,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestForceReplicationWorkflow() { ListWorkflowsPageSize: 1, PageCountPerExecution: 4, EnableVerification: true, - TargetClusterEndpoint: "test-target", + TargetClusterName: "test-target", }) s.True(env.IsWorkflowCompleted()) @@ -167,8 +167,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestContinueAsNew() { PageCountPerExecution: testMaxPageCountPerExecution, NextPageToken: []byte("fake-page-token-2"), EnableVerification: true, - TargetClusterEndpoint: "test-target", - TargetClusterName: "", + TargetClusterName: "test-target", VerifyIntervalInSeconds: defaultVerifyIntervalInSeconds, LastCloseTime: closeTime, LastStartTime: startTime, @@ -194,7 +193,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestContinueAsNew() { ListWorkflowsPageSize: 1, 
PageCountPerExecution: testMaxPageCountPerExecution, EnableVerification: true, - TargetClusterEndpoint: "test-target", + TargetClusterName: "test-target", NextPageToken: []byte("fake-initial-page-token"), }, expectContinueAsNew, @@ -295,7 +294,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestInvalidInput() { // Empty namespace }, { - // Empty TargetClusterEndpoint + // Empty TargetClusterName Namespace: uuid.NewString(), EnableVerification: true, }, @@ -438,7 +437,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestGenerateReplicationTaskNonRetrya ListWorkflowsPageSize: 1, PageCountPerExecution: 4, EnableVerification: true, - TargetClusterEndpoint: "test-target", + TargetClusterName: "test-target", }) s.True(env.IsWorkflowCompleted()) @@ -495,7 +494,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestVerifyReplicationTaskNonRetryabl ListWorkflowsPageSize: 1, PageCountPerExecution: 4, EnableVerification: true, - TargetClusterEndpoint: "test-target", + TargetClusterName: "test-target", }) s.True(env.IsWorkflowCompleted()) @@ -705,6 +704,7 @@ type heartbeatRecordingInterceptor struct { seedRecordedHeartbeats []seedReplicationQueueWithUserDataEntriesHeartbeatDetails replicationRecordedHeartbeats []replicationTasksHeartbeatDetails generateReplicationRecordedHeartbeats []int + replicateBatchRecordedHeartbeats []replicateBatchHeartbeat T *testing.T } @@ -725,6 +725,8 @@ func (i *heartbeatRecordingInterceptor) RecordHeartbeat(ctx context.Context, det i.replicationRecordedHeartbeats = append(i.replicationRecordedHeartbeats, d) } else if d, ok := details[0].(int); ok { i.generateReplicationRecordedHeartbeats = append(i.generateReplicationRecordedHeartbeats, d) + } else if d, ok := details[0].(replicateBatchHeartbeat); ok { + i.replicateBatchRecordedHeartbeats = append(i.replicateBatchRecordedHeartbeats, d) } else { assert.Fail(i.T, "invalid heartbeat details") } diff --git a/service/worker/migration/fx.go b/service/worker/migration/fx.go index c1eda546ef3..b92ca6ca85a 
100644 --- a/service/worker/migration/fx.go +++ b/service/worker/migration/fx.go @@ -2,13 +2,16 @@ package migration import ( "context" + "fmt" "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/sdk/activity" sdkworker "go.temporal.io/sdk/worker" "go.temporal.io/sdk/workflow" "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/chasm" serverClient "go.temporal.io/server/client" + "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/config" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/headers" @@ -18,6 +21,7 @@ import ( "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/resource" + "go.temporal.io/server/common/sdk" workercommon "go.temporal.io/server/service/worker/common" "go.uber.org/fx" ) @@ -32,6 +36,7 @@ type ( FrontendClient workflowservice.WorkflowServiceClient ClientFactory serverClient.Factory ClientBean serverClient.Bean + ClusterMetadata cluster.Metadata NamespaceReplicationQueue persistence.NamespaceReplicationQueue TaskManager persistence.TaskManager Logger log.Logger @@ -39,6 +44,7 @@ type ( DynamicCollection *dynamicconfig.Collection WorkflowVerifier WorkflowVerifier ChasmRegistry *chasm.Registry + SDKClientFactory sdk.ClientFactory } fxResult struct { @@ -48,21 +54,49 @@ type ( replicationWorkerComponent struct { initParams + activities *activities + } + + // shardedWorkerComponent registers the sharded force-replication + // workflow + ReplicateBatch activity on their dedicated TQ. Holds its + // own *activities instance so its activity registration is isolated + // from the default-TQ worker — sharded inject paths don't accidentally + // land on the legacy MigrationActivityTQ. 
+ shardedWorkerComponent struct { + activities *activities } ) var Module = fx.Options( fx.Provide(NewResult), + fx.Provide(NewShardedResult), fx.Provide(workflowVerifierProvider), ) -func NewResult(params initParams) fxResult { - component := &replicationWorkerComponent{ - initParams: params, +func NewResult(params initParams) (fxResult, error) { + a, err := newActivitiesFromParams(params, forceReplicationWorkflowName) + if err != nil { + return fxResult{}, err } return fxResult{ - Component: component, + Component: &replicationWorkerComponent{ + initParams: params, + activities: a, + }, + }, nil +} + +// NewShardedResult constructs the sharded WorkerComponent. The component +// owns its own *activities clone so registration against the sharded TQ +// doesn't bleed into the legacy worker. +func NewShardedResult(params initParams) (fxResult, error) { + a, err := newActivitiesFromParams(params, shardedForceReplicationWorkflowName) + if err != nil { + return fxResult{}, err } + return fxResult{ + Component: &shardedWorkerComponent{activities: a}, + }, nil } func (wc *replicationWorkerComponent) RegisterWorkflow(registry sdkworker.Registry) { @@ -80,7 +114,16 @@ func (wc *replicationWorkerComponent) DedicatedWorkflowWorkerOptions() *workerco } func (wc *replicationWorkerComponent) RegisterActivities(registry sdkworker.Registry) { - registry.RegisterActivity(wc.activities()) + // DisableAlreadyRegisteredCheck because the sharded WorkerComponent + // shares the *activities method set; whichever component registers + // first on the default worker wins (per the worker.go upgrade-hack + // pass), and the second component's reflection-based registration + // would otherwise panic on every method name. The default worker + // isn't dispatched to by either workflow — both have dedicated + // activity workers — so winner-takes-all is fine. 
+ registry.RegisterActivityWithOptions(wc.activities, activity.RegisterOptions{ + DisableAlreadyRegisteredCheck: true, + }) } func (wc *replicationWorkerComponent) DedicatedActivityWorkerOptions() *workercommon.DedicatedWorkerOptions { @@ -92,6 +135,56 @@ func (wc *replicationWorkerComponent) DedicatedActivityWorkerOptions() *workerco } } +func (sc *shardedWorkerComponent) RegisterWorkflow(registry sdkworker.Registry) { + registry.RegisterWorkflowWithOptions(ShardedForceReplicationWorkflow, workflow.RegisterOptions{ + Name: shardedForceReplicationWorkflowName, + }) + registry.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{ + Name: forceTaskQueueUserDataReplicationWorkflow, + }) +} + +func (sc *shardedWorkerComponent) DedicatedWorkflowWorkerOptions() *workercommon.DedicatedWorkerOptions { + // Workflow + activity share the same TQ so the workflow's default + // ExecuteActivity (no explicit TaskQueue) routes to our dedicated + // activity worker rather than the default-TQ worker. Without a + // dedicated workflow worker here the workflow would land on + // default-worker-tq and its activities would pile up on the (separate, + // our-TQ) dedicated activity worker, unscheduled. + // + // LocalActivityWorkerOnly is essential: by default a worker polls for + // both workflow and activity tasks on its TQ. Since the activity + // worker (a separate sdkworker.Worker) also polls this TQ and is the + // one that owns the registered activities, leaving activity polling + // enabled here means this worker races for activity tasks and + // dispatches them with no registrations — ActivityNotRegisteredError, + // "Supported types: []". 
+ return &workercommon.DedicatedWorkerOptions{ + TaskQueue: primitives.MigrationShardedActivityTQ, + Options: sdkworker.Options{ + LocalActivityWorkerOnly: true, + }, + } +} + +func (sc *shardedWorkerComponent) RegisterActivities(registry sdkworker.Registry) { + // See replicationWorkerComponent.RegisterActivities — both components + // share the *activities method set, so the second registration on the + // default worker would otherwise panic. + registry.RegisterActivityWithOptions(sc.activities, activity.RegisterOptions{ + DisableAlreadyRegisteredCheck: true, + }) +} + +func (sc *shardedWorkerComponent) DedicatedActivityWorkerOptions() *workercommon.DedicatedWorkerOptions { + return &workercommon.DedicatedWorkerOptions{ + TaskQueue: primitives.MigrationShardedActivityTQ, + Options: sdkworker.Options{ + BackgroundActivityContext: headers.SetCallerType(context.Background(), headers.CallerTypePreemptable), + }, + } +} + func workflowVerifierProvider() WorkflowVerifier { return func( _ context.Context, @@ -108,23 +201,42 @@ func workflowVerifierProvider() WorkflowVerifier { } } -func (wc *replicationWorkerComponent) activities() *activities { - return &activities{ - HistoryShardCount: wc.PersistenceConfig.NumHistoryShards, - executionManager: wc.ExecutionManager, - NamespaceRegistry: wc.NamespaceRegistry, - HistoryClient: wc.HistoryClient, - frontendClient: wc.FrontendClient, - clientFactory: wc.ClientFactory, - clientBean: wc.ClientBean, - namespaceReplicationQueue: wc.NamespaceReplicationQueue, - taskManager: wc.TaskManager, - Logger: wc.Logger, - MetricsHandler: wc.MetricsHandler, - forceReplicationMetricsHandler: wc.MetricsHandler.WithTags(metrics.WorkflowTypeTag(forceReplicationWorkflowName)), - generateMigrationTaskViaFrontend: dynamicconfig.WorkerGenerateMigrationTaskViaFrontend.Get(wc.DynamicCollection), - enableHistoryRateLimiter: dynamicconfig.WorkerEnableHistoryRateLimiter.Get(wc.DynamicCollection), - workflowVerifier: wc.WorkflowVerifier, - 
chasmRegistry: wc.ChasmRegistry, +// newActivitiesFromParams builds the shared *activities struct from the +// fx params. workflowTypeName tags the forceReplicationMetricsHandler so +// the legacy and sharded variants emit force-replication metrics under +// distinct workflow_type tags. +// +// adminClient is the local admin client cached by ClientBean at startup. +// Routing through the bean (rather than constructing a fresh wrapper via +// NewLocalAdminClientWithTimeout) reuses the same retry+metric wrapper +// every other consumer in the process sees, and guarantees adminClient +// is non-nil so the inject and verify paths can use it without nil +// guarding. A lookup failure indicates ClusterMetadata is misconfigured; +// surfacing it as an fx error fails app start cleanly rather than mid-run. +func newActivitiesFromParams(params initParams, workflowTypeName string) (*activities, error) { + localCluster := params.ClusterMetadata.GetCurrentClusterName() + localAdmin, err := params.ClientBean.GetRemoteAdminClient(localCluster) + if err != nil { + return nil, fmt.Errorf("migration: local admin client missing from ClientBean for cluster %q: %w", localCluster, err) } + return &activities{ + HistoryShardCount: params.PersistenceConfig.NumHistoryShards, + executionManager: params.ExecutionManager, + NamespaceRegistry: params.NamespaceRegistry, + HistoryClient: params.HistoryClient, + frontendClient: params.FrontendClient, + adminClient: localAdmin, + clientFactory: params.ClientFactory, + clientBean: params.ClientBean, + namespaceReplicationQueue: params.NamespaceReplicationQueue, + taskManager: params.TaskManager, + Logger: params.Logger, + MetricsHandler: params.MetricsHandler, + forceReplicationMetricsHandler: params.MetricsHandler.WithTags(metrics.WorkflowTypeTag(workflowTypeName)), + generateMigrationTaskViaFrontend: dynamicconfig.WorkerGenerateMigrationTaskViaFrontend.Get(params.DynamicCollection), + enableHistoryRateLimiter: 
dynamicconfig.WorkerEnableHistoryRateLimiter.Get(params.DynamicCollection), + workflowVerifier: params.WorkflowVerifier, + chasmRegistry: params.ChasmRegistry, + sdkClientFactory: params.SDKClientFactory, + }, nil } diff --git a/service/worker/migration/sharded_activity.go b/service/worker/migration/sharded_activity.go new file mode 100644 index 00000000000..336bfa60db1 --- /dev/null +++ b/service/worker/migration/sharded_activity.go @@ -0,0 +1,791 @@ +package migration + +import ( + "context" + "errors" + "fmt" + "math" + "slices" + "time" + + commonpb "go.temporal.io/api/common/v1" + enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/api/serviceerror" + "go.temporal.io/sdk/activity" + "go.temporal.io/sdk/temporal" + "go.temporal.io/server/api/adminservice/v1" + "go.temporal.io/server/common" + "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/namespace" + "go.temporal.io/server/common/quotas" +) + +// ReplicateBatch is the per-batch activity body for the sharded force +// replication workflow. Runs inject (skipped on Resume) then verify, +// signal-releasing completed shards mid-flight as their cumulative +// idle cost crosses IdleShardCost. On workflow-initiated cancellation +// it enters drain mode and returns a replicateBatchResult carrying +// any still-unverified execs. Drain-mode signal traffic is suppressed: +// once we know we're about to return, there's no point racing a +// signal against the return value. +func (a *activities) ReplicateBatch(ctx context.Context, req *shardedBatchReq) (replicateBatchResult, error) { + // Flatten once so per-exec bookkeeping (verified[], attempts[], + // nextRetryAt[]) can stay index-based. 
+ execs := req.Executions.flatten() + execCount := len(execs) + if execCount == 0 { + return replicateBatchResult{}, nil + } + + remoteAdminClient, err := a.clientBean.GetRemoteAdminClient(req.TargetClusterName) + if err != nil { + return replicateBatchResult{}, fmt.Errorf("get remote admin client for %s: %w", req.TargetClusterName, err) + } + + var hb replicateBatchHeartbeat + if activity.HasHeartbeatDetails(ctx) { + _ = activity.GetHeartbeatDetails(ctx, &hb) + } + + // ---- Inject phase ---- + if !req.Resume && !hb.InjectDone { + startIdx := hb.NextInjectIdx + if err := a.runInjectPhase(ctx, req, execs, startIdx); err != nil { + return replicateBatchResult{}, err + } + activity.RecordHeartbeat(ctx, replicateBatchHeartbeat{InjectDone: true}) + } + + // inject-only path: skip verify and let the workflow accounting + // release shards on activity return. + if req.DisableVerification { + return replicateBatchResult{ + CompletedShards: req.Executions.sortedShards(), + VerifiedCount: 0, + }, nil + } + + // Namespace lookup feeds the verify phase's retention/zombie skip + // check (checkSkipWorkflowExecution needs ns.Retention()). Snapshotted + // once per activity; we don't track config changes mid-batch. + ns, err := a.NamespaceRegistry.GetNamespaceByID(namespace.ID(req.NamespaceID)) + if err != nil { + return replicateBatchResult{}, fmt.Errorf("look up namespace %s: %w", req.NamespaceID, err) + } + + return a.runVerifyPhase(ctx, req, execs, execCount, remoteAdminClient, ns) +} + +// runVerifyPhase is the verify-phase loop body of ReplicateBatch. It +// owns per-exec bookkeeping, the drain-transition handoff, and the +// per-iteration completion / stuck-shard / signal-release decisions +// — extracted from ReplicateBatch to keep its cognitive complexity +// under the linter cap. 
+func (a *activities) runVerifyPhase( + ctx context.Context, + req *shardedBatchReq, + execs []*shardedExecutionInfo, + execCount int, + remoteAdminClient adminservice.AdminServiceClient, + ns *namespace.Namespace, +) (replicateBatchResult, error) { + verified := make([]bool, execCount) + attempts := make([]int, execCount) + nextRetryAt := make([]time.Time, execCount) + doneCount := 0 + + shards := newShardVerifyTracker(execs, req.Resume, req.NoProgressByShard) + + var draining bool + var drainStartAt time.Time + + // callCtx is what attemptVerifyExec uses for DescribeMutableState. + // In drain mode we swap to a detached context: the parent ctx is + // already dead (that's what triggered the transition), so reusing + // it would make every drain-mode RPC fail instantly. + callCtx := ctx + // Pre-create the drain context up front and defer cancel right away + // so go vet's lostcancel pass sees the canonical pattern. + drainCtx, drainCancel := context.WithCancel(context.Background()) + defer drainCancel() + + for { + // Worker shutdown short-circuits drain mode entirely. The SDK + // closes WorkerStopChannel WorkerStopTimeout before forcibly + // returning; burning that window on DMS calls that can't drive + // their results back is worse than returning current state and + // letting ResumeShards / RecoveredBuckets recover next cycle. + // Re-checked each iteration because shutdown can fire after + // we've already entered drain — the detached ctx wouldn't + // notice on its own. + select { + case <-activity.GetWorkerStopChannel(ctx): + return replicateBatchResult{ + CompletedShards: shards.allCompleted(), + InFlight: buildInFlight(execs, verified, shards, time.Now()), + VerifiedCount: int64(doneCount), + }, nil + default: + } + + // Workflow-initiated activity cancellation (drainForCAN) + // transitions us into drain mode with the full DrainGrace window. 
+ // WaitForCancellation=true on the activity options guarantees the + // workflow blocks for us, so the grace window is genuinely + // available — swap callCtx onto a detached deadline so + // DescribeMutableState keeps working after the parent ctx died. + if !draining && ctx.Err() != nil { + draining = true + drainStartAt = time.Now() + // Start the drain budget timer here rather than at activity + // entry so the grace window measures from drain transition, + // not from activity start. + time.AfterFunc(req.DrainGrace, drainCancel) + callCtx = drainCtx + } + + passDelta, minNextRetry, ctxAborted, vErr := a.runVerifyPass( + ctx, callCtx, remoteAdminClient, ns, req, execs, verified, attempts, nextRetryAt, shards) + // Fold partial progress in before the error check — the SDK + // discards the activity result on failure, so the only way + // the workflow learns about partially-verified execs on the + // error path is via wrapBatchVerifyError encoding the count + // as ApplicationError details below. + doneCount += passDelta + if vErr != nil { + return replicateBatchResult{}, wrapBatchVerifyError(vErr, int64(doneCount)) + } + + activity.RecordHeartbeat(ctx, replicateBatchHeartbeat{InjectDone: true}) + + if done, result, err := a.evaluateVerifyIteration( + ctx, req, execs, verified, shards, doneCount, execCount, draining, drainStartAt); err != nil { + return replicateBatchResult{}, wrapBatchVerifyError(err, int64(doneCount)) + } else if done { + return result, nil + } + + // If the inner loop aborted because callCtx died, skip the sleep + // entirely so the outer-loop top sees the new state promptly + // (normal → drain transition, or drain → exit). + if ctxAborted { + continue + } + + waitNextTick(ctx, callCtx, minNextRetry, draining, drainStartAt, req.DrainGrace) + } +} + +// evaluateVerifyIteration runs the post-pass checks (clean completion, +// stuck-shard backstop, drain-exit decision, mid-flight signal release) +// for a single verify-loop iteration. 
Returns done=true with the result +// when the loop should exit; (false, _, nil) means continue; a non-nil error aborts the loop. +func (a *activities) evaluateVerifyIteration( + ctx context.Context, + req *shardedBatchReq, + execs []*shardedExecutionInfo, + verified []bool, + shards shardVerifyTracker, + doneCount, execCount int, + draining bool, + drainStartAt time.Time, +) (bool, replicateBatchResult, error) { + // Clean completion — every exec verified. + if doneCount >= execCount { + return true, replicateBatchResult{ + CompletedShards: shards.allCompleted(), + VerifiedCount: int64(doneCount), + }, nil + } + + if draining { + // Drain-mode exit checks. No signals here — the return value + // carries everything the workflow needs (completed shards + + // unverified execs grouped by shard with their cumulative + // no-progress duration). The per-shard no-progress backstop is + // deliberately skipped: drain is bounded by DrainGrace and the + // outstanding execs need to flow back via InFlight for CAN + // carry-over, not surface as a ShardNoProgress failure. + if shouldExitDrain(req, shards, drainStartAt) { + return true, replicateBatchResult{ + CompletedShards: shards.allCompleted(), + InFlight: buildInFlight(execs, verified, shards, time.Now()), + VerifiedCount: int64(doneCount), + }, nil + } + return false, replicateBatchResult{}, nil + } + + // Per-shard cumulative no-progress backstop. + if sErr := a.checkStuckShard(req, shards, execs, verified, doneCount, execCount); sErr != nil { + return false, replicateBatchResult{}, sErr + } + + if err := a.maybeSignalRelease(ctx, req, shards); err != nil { + return false, replicateBatchResult{}, err + } + return false, replicateBatchResult{}, nil +} + +// runInjectPhase walks execs in flattened order, generating one +// replication task per exec under a per-batch RPS limiter. 
Cancellation +// mid-loop returns a CanceledError so spawnBatch's IsCanceledError +// check fires and the batch's batchExecs entry is preserved for +// RecoveredBuckets re-injection next cycle. Already-injected execs +// are re-injected harmlessly — replication dedupes per (namespace, +// wf, run). +func (a *activities) runInjectPhase(ctx context.Context, req *shardedBatchReq, execs []*shardedExecutionInfo, startIdx int) error { + rateLimiter := quotas.NewRateLimiter(req.PerBatchGenerateRPS, int(math.Ceil(req.PerBatchGenerateRPS))) + for i := startIdx; i < len(execs); i++ { + ex := execs[i] + if ctx.Err() != nil { + return temporal.NewCanceledError("inject phase cancelled") + } + if err := a.generateReplicationTaskForExec(ctx, rateLimiter, req, ex); err != nil { + if ctx.Err() != nil { + return temporal.NewCanceledError("inject phase cancelled") + } + if !common.IsNotFoundError(err) { + return err + } + a.Logger.Warn("force-replication-sharded ignore replication task due to NotFoundServiceError", + tag.WorkflowNamespaceID(req.NamespaceID), + tag.WorkflowID(ex.BusinessID), + tag.WorkflowRunID(ex.RunID), + tag.Error(err)) + } + activity.RecordHeartbeat(ctx, replicateBatchHeartbeat{NextInjectIdx: i + 1}) + } + return nil +} + +// runVerifyPass runs one pass over every unverified exec, attempting a +// verify on those whose backoff timer has expired. Returns the count of +// execs newly verified this pass, the earliest pending retry deadline +// (for sleep scheduling), and whether callCtx died mid-pass — in which +// case the outer loop's top reassesses (drain transition or exit). +// Returns a non-nil error only for hard errors from the verify path; +// ctx-derived errors set ctxAborted instead so the outer loop owns +// the decision about what to do next. +// +// ctx is the activity ctx, used only for heartbeating — a single pass +// over a large batch can outlast HeartbeatTimeout if we only heartbeat +// once at the end, so we tick per attempted exec. 
// callCtx is what the
// DMS call rides on (the detached drain ctx in drain mode).
//
// verified, attempts and nextRetryAt are indexed in parallel with execs
// and are updated in place, as is the shards tracker.
func (a *activities) runVerifyPass(
	ctx context.Context,
	callCtx context.Context,
	remoteAdminClient adminservice.AdminServiceClient,
	ns *namespace.Namespace,
	req *shardedBatchReq,
	execs []*shardedExecutionInfo,
	verified []bool,
	attempts []int,
	nextRetryAt []time.Time,
	shards shardVerifyTracker,
) (int, time.Time, bool, error) {
	// now is sampled once per pass; the backoff gate below compares every
	// exec against this single timestamp.
	now := time.Now()
	var minNextRetry time.Time
	verifiedDelta := 0
	for i, ex := range execs {
		if verified[i] {
			continue
		}
		// Still backing off: remember its deadline for sleep scheduling
		// and move on without an attempt (and without a heartbeat).
		if !nextRetryAt[i].IsZero() && nextRetryAt[i].After(now) {
			minNextRetry = earliest(minNextRetry, nextRetryAt[i])
			continue
		}

		ok, err := a.attemptVerifyExec(callCtx, remoteAdminClient, ns, req, ex)
		if err != nil {
			if callCtx.Err() != nil {
				// callCtx is dead. Two cases: (1) normal-mode parent ctx
				// was just cancelled mid-call — the outer-loop top will
				// promote to drain on the next iteration; (2) drain-mode
				// detached ctx expired — the drain exit check fires.
				return verifiedDelta, minNextRetry, true, nil
			}
			return verifiedDelta, minNextRetry, false, err
		}

		if ok {
			verified[i] = true
			verifiedDelta++
			shards.recordVerified(ex.Shard, time.Now())
		} else {
			// Not verified yet: bump the attempt count and schedule the
			// next try from the current wall clock, not the pass-start now.
			attempts[i]++
			nextRetryAt[i] = time.Now().Add(backoffDelay(attempts[i]))
			minNextRetry = earliest(minNextRetry, nextRetryAt[i])
		}

		activity.RecordHeartbeat(ctx, replicateBatchHeartbeat{InjectDone: true})
	}
	return verifiedDelta, minNextRetry, false, nil
}

// earliest returns the earlier of cur (which may be zero) and candidate.
// Used to track the next-due retry deadline across the verify pass.
// A zero cur means "no deadline recorded yet", so candidate always wins.
func earliest(cur, candidate time.Time) time.Time {
	if cur.IsZero() || candidate.Before(cur) {
		return candidate
	}
	return cur
}

// checkStuckShard fails non-retryably if any shard has gone longer than
// req.ShardNoProgress without a verified outcome.
Duration is cumulative +// across CAN cycles via tracker seeding. +func (a *activities) checkStuckShard( + req *shardedBatchReq, + shards shardVerifyTracker, + execs []*shardedExecutionInfo, + verified []bool, + doneCount, total int, +) error { + stuckShard, stuckDur, ok := shards.pickStuck(time.Now(), req.ShardNoProgress) + if !ok { + return nil + } + msg := fmt.Sprintf("shard %d no progress for %v", stuckShard, stuckDur) + if stuckIdx, found := firstUnverifiedOnShard(execs, verified, stuckShard); found { + stuck := execs[stuckIdx] + msg = fmt.Sprintf("shard %d no progress for %v on %s/%s (%d/%d done)", + stuckShard, stuckDur, stuck.BusinessID, stuck.RunID, doneCount, total) + } + return temporal.NewNonRetryableApplicationError(msg, "ShardNoProgress", nil) +} + +// shouldExitDrain reports whether the drain-mode exit conditions are +// met: either the grace window has expired, or the cumulative idle cost +// across completed-but-unsignaled shards crossed the threshold. +func shouldExitDrain(req *shardedBatchReq, shards shardVerifyTracker, drainStartAt time.Time) bool { + if time.Since(drainStartAt) >= req.DrainGrace { + return true + } + return shards.totalIdleCost(time.Now()) >= req.IdleShardCost +} + +// maybeSignalRelease signals the workflow to release any +// completed-but-unsignaled shards if their cumulative idle cost crossed +// the threshold. Only fires in normal mode — drain mode rides the +// activity result instead. +// +// Ctx-canceled errors from signalReleaseShards are suppressed: a +// workflow-initiated cancel arriving mid-signal would otherwise +// surface as a wrapped ctx-canceled error (not temporal.CanceledError) +// that the workflow side wouldn't recognise via IsCanceledError — +// turning a clean CAN into an error exit. Suppressing here lets the +// outer loop see ctx.Err() at its top and promote to drain mode +// normally. 
func (a *activities) maybeSignalRelease(ctx context.Context, req *shardedBatchReq, shards shardVerifyTracker) error {
	// Below the idle-cost threshold: nothing to do yet.
	if shards.totalIdleCost(time.Now()) < req.IdleShardCost {
		return nil
	}
	releaseList := shards.awaitingRelease()
	if len(releaseList) == 0 {
		return nil
	}
	if err := a.signalReleaseShards(ctx, req, releaseList); err != nil {
		// Swallow the failure when ctx is already dead. The shards are
		// NOT marked released in that case.
		if ctx.Err() != nil {
			return nil
		}
		return err
	}
	// Only mark released after the signal went through.
	shards.markReleased(releaseList)
	return nil
}

// waitNextTick sleeps until the next exec is due for retry, capped by
// DrainGrace remaining when in drain mode. In drain mode the parent ctx
// is already dead, so we wake on the detached drain ctx instead — using
// the parent ctx would tight-loop on its Done channel.
//
// The sleep is never shorter than 50ms unless the drain-grace cap applies
// or the relevant ctx is done.
//
// NOTE(review): the drain cap is only applied while remaining > 0. Once
// the grace window has already elapsed the sleep is not shortened, so the
// full retry delay can still be slept — confirm the caller always runs the
// drain exit check before reaching here.
func waitNextTick(
	ctx, callCtx context.Context,
	minNextRetry time.Time,
	draining bool,
	drainStartAt time.Time,
	drainGrace time.Duration,
) {
	// 50ms floor; stretched to the next retry deadline when that is later.
	sleepDur := 50 * time.Millisecond
	if !minNextRetry.IsZero() {
		if delta := time.Until(minNextRetry); delta > sleepDur {
			sleepDur = delta
		}
	}
	if draining {
		if remaining := drainGrace - time.Since(drainStartAt); remaining > 0 && remaining < sleepDur {
			sleepDur = remaining
		}
		select {
		case <-time.After(sleepDur):
		case <-callCtx.Done():
		}
		return
	}
	select {
	case <-time.After(sleepDur):
	case <-ctx.Done():
		// ctx cancel just sets draining on the next iteration; don't
		// unwind here.
	}
}

// generateReplicationTaskForExec is the per-exec inject wrapper around
// generateWorkflowReplicationTask; supplies the sharded-only
// single-target-clusters slice and the generateViaFrontend flag.
// It also records the call's duration on the
// GenerateReplicationTaskLatency timer, tagged with the namespace,
// whether or not the call succeeds.
func (a *activities) generateReplicationTaskForExec(
	ctx context.Context,
	rateLimiter quotas.RateLimiter,
	req *shardedBatchReq,
	ex *shardedExecutionInfo,
) error {
	start := time.Now()
	// Deferred so the timer is recorded on both success and error returns.
	defer func() {
		a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(req.Namespace)).
			Timer(metrics.GenerateReplicationTaskLatency.Name()).Record(time.Since(start))
	}()
	return a.generateWorkflowReplicationTask(
		ctx,
		rateLimiter,
		req.Namespace,
		req.NamespaceID,
		ex.ExecutionInfo,
		[]string{req.TargetClusterName},
		a.generateMigrationTaskViaFrontend(),
	)
}

// attemptVerifyExec runs the source-describe + target-applied check for
// a single execution and returns whether it's now verified.
//
// Why this isn't a delegation to verifySingleReplicationTask: that
// helper folds BUSY_WORKFLOW into the generic notVerified result. We
// inline the DMS call here to keep busy-workflow as a distinct metric
// counter, preserving the "passive cluster apply is in progress"
// signal.
//
// Outcomes, in the order they are checked:
//   - describe succeeded: the workflowVerifier decides; verified counts
//     as success, anything else as pending.
//   - NotFound: checkSkipWorkflowExecution decides.
//   - NamespaceNotFound: non-retryable application error.
//   - ResourceExhausted with cause BUSY_WORKFLOW: (false, nil).
//   - any other error: counted as failed and returned wrapped.
func (a *activities) attemptVerifyExec(
	ctx context.Context,
	remoteAdminClient adminservice.AdminServiceClient,
	ns *namespace.Namespace,
	req *shardedBatchReq,
	ex *shardedExecutionInfo,
) (bool, error) {
	attemptStart := time.Now()
	// Whole-attempt latency, recorded on every return path.
	defer func() {
		a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(req.Namespace)).
			Timer(metrics.VerifyReplicationTaskLatency.Name()).Record(time.Since(attemptStart))
	}()

	archetype, err := a.archetypeIDToName(ctx, ex.ArchetypeID)
	if err != nil {
		return false, err
	}

	// Request handed to the verifier helpers below. Executions and
	// VerifyInterval are left at their zero values.
	vreq := &verifyReplicationTasksRequest{
		Namespace:         req.Namespace,
		NamespaceID:       req.NamespaceID,
		TargetClusterName: req.TargetClusterName,
	}

	describeStart := time.Now()
	mu, err := remoteAdminClient.DescribeMutableState(ctx, &adminservice.DescribeMutableStateRequest{
		Namespace: req.Namespace,
		Execution: &commonpb.WorkflowExecution{
			WorkflowId: ex.BusinessID,
			RunId:      ex.RunID,
		},
		Archetype:       archetype,
		ArchetypeId:     ex.ArchetypeID,
		SkipForceReload: true,
	})
	// Describe-only latency. Unlike the other metrics in this function it
	// is recorded without the namespace tag.
	a.forceReplicationMetricsHandler.Timer(metrics.VerifyDescribeMutableStateLatency.Name()).Record(time.Since(describeStart))

	nsTag := metrics.NamespaceTag(req.Namespace)

	if err == nil {
		result, vErr := a.workflowVerifier(ctx, vreq, remoteAdminClient, a.adminClient, ns, ex.ExecutionInfo, mu)
		if vErr != nil {
			return false, vErr
		}
		if result.isVerified() {
			a.forceReplicationMetricsHandler.WithTags(nsTag).Counter(metrics.VerifyReplicationTaskSuccess.Name()).Record(1)
			return true, nil
		}
		a.forceReplicationMetricsHandler.WithTags(nsTag).Counter(metrics.VerifyReplicationTaskPending.Name()).Record(1)
		return false, nil
	}

	if _, ok := errors.AsType[*serviceerror.NotFound](err); ok {
		a.forceReplicationMetricsHandler.WithTags(nsTag).Counter(metrics.VerifyReplicationTaskNotFound.Name()).Record(1)
		// Retention/zombie path: a not-found execution may already be
		// deleted on source (zombie or past retention), in which case it
		// never needs to replicate — treat that as verified so the
		// shard's completion accounting moves forward.
		result, sErr := a.checkSkipWorkflowExecution(ctx, vreq, ex.ExecutionInfo, ns)
		if sErr != nil {
			return false, sErr
		}
		return result.isVerified(), nil
	}

	if _, ok := errors.AsType[*serviceerror.NamespaceNotFound](err); ok {
		return false, temporal.NewNonRetryableApplicationError(
			"failed to describe workflow from the remote cluster", "NamespaceNotFound", err)
	}

	if resExhausted, ok := errors.AsType[*serviceerror.ResourceExhausted](err); ok && resExhausted.Cause == enumspb.RESOURCE_EXHAUSTED_CAUSE_BUSY_WORKFLOW {
		// Passive cluster holds the workflow cache lock while applying
		// history during SyncWorkflowStateTask. Counted separately from
		// pending so the "apply is in progress" signal stays visible,
		// but the workflow-side treatment matches pending — per-exec
		// backoff applies and the per-shard last-progress timer does
		// not move (it only updates on verified outcomes).
		a.forceReplicationMetricsHandler.WithTags(nsTag).Counter(metrics.VerifyReplicationTaskBusy.Name()).Record(1)
		return false, nil
	}

	// Anything else is a hard failure, tagged with the service error type.
	a.forceReplicationMetricsHandler.WithTags(nsTag, metrics.ServiceErrorTypeTag(err)).
		Counter(metrics.VerifyReplicationTaskFailed.Name()).Record(1)
	return false, fmt.Errorf("describe workflow on remote cluster: %w", err)
}

// signalReleaseShards sends the mid-flight ReleaseShards signal to the
// parent workflow, freeing the listed shards in the workflow's
// shardInFlight set so the packer can dispatch fresh batches against them
// while this activity stays running on its still-pending shards.
//
// No retry wrapping: a transient failure propagates up so the activity
// fails, the workflow records it via lastErr, and the in-flight batch is
// recovered into the next CAN's RecoveredBuckets — preferable to silently
// swallowing the error here and stranding completed shards.
+func (a *activities) signalReleaseShards(ctx context.Context, req *shardedBatchReq, shards []int32) error { + info := activity.GetInfo(ctx) + return a.sdkClientFactory.GetSystemClient().SignalWorkflow(ctx, info.WorkflowExecution.ID, info.WorkflowExecution.RunID, releaseShardsSignalName, releaseShardsPayload{ + BatchID: req.BatchID, + Shards: shards, + }) +} + +// shardVerify holds per-shard verify-phase state for one batch. +type shardVerify struct { + pending int + doneAt time.Time // set when pending first hits zero; cleared on signal release + released bool // ReleaseShards signal already sent + lastProgress time.Time // wall time of the most recent verified outcome +} + +type shardVerifyTracker map[int32]shardVerify + +func newShardVerifyTracker( + execs []*shardedExecutionInfo, + resume bool, + noProgressByShard map[int32]time.Duration, +) shardVerifyTracker { + t := shardVerifyTracker{} + for _, ex := range execs { + sv := t[ex.Shard] + sv.pending++ + t[ex.Shard] = sv + } + nowSeed := time.Now() + for sh, sv := range t { + if resume { + sv.lastProgress = nowSeed.Add(-noProgressByShard[sh]) + } else { + sv.lastProgress = nowSeed + } + t[sh] = sv + } + return t +} + +func (t shardVerifyTracker) recordVerified(sh int32, now time.Time) { + sv := t[sh] + sv.pending-- + sv.lastProgress = now + if sv.pending == 0 { + sv.doneAt = now + } + t[sh] = sv +} + +func (t shardVerifyTracker) markReleased(shards []int32) { + for _, sh := range shards { + sv := t[sh] + sv.released = true + sv.doneAt = time.Time{} + t[sh] = sv + } +} + +// totalIdleCost sums idle time across shards that are completed but not +// yet released — the "shard-seconds" unit the IdleShardCost threshold is +// denominated in. 
+func (t shardVerifyTracker) totalIdleCost(now time.Time) time.Duration { + var total time.Duration + for _, sv := range t { + if !sv.doneAt.IsZero() { + total += now.Sub(sv.doneAt) + } + } + return total +} + +// awaitingRelease returns completed-but-not-yet-signaled shard IDs in +// ascending order so the signal payload is deterministic across replays. +func (t shardVerifyTracker) awaitingRelease() []int32 { + var out []int32 + for sh, sv := range t { + if !sv.doneAt.IsZero() { + out = append(out, sh) + } + } + slices.Sort(out) + return out +} + +// allCompleted returns every shard that finished during this activity +// run — both signal-released and still awaiting release at return. +func (t shardVerifyTracker) allCompleted() []int32 { + var out []int32 + for sh, sv := range t { + if sv.released || !sv.doneAt.IsZero() { + out = append(out, sh) + } + } + slices.Sort(out) + return out +} + +// pickStuck returns (shard, age, true) for the lowest-numbered shard +// whose cumulative no-progress duration meets or exceeds threshold. +func (t shardVerifyTracker) pickStuck(now time.Time, threshold time.Duration) (int32, time.Duration, bool) { + var ( + minShard int32 + minAge time.Duration + found bool + ) + for sh, sv := range t { + if sv.pending <= 0 { + continue + } + age := now.Sub(sv.lastProgress) + if age < threshold { + continue + } + if !found || sh < minShard { + minShard = sh + minAge = age + found = true + } + } + return minShard, minAge, found +} + +// buildInFlight groups unverified execs by shard then businessID and +// attaches the cumulative no-progress duration per shard, for the +// drain-mode activity return. Shards with zero unverified execs are +// reported via CompletedShards instead. 
+func buildInFlight( + execs []*shardedExecutionInfo, + verified []bool, + shards shardVerifyTracker, + now time.Time, +) []ResumeShard { + byShard := map[int32]map[string][]RunEntry{} + for i, ex := range execs { + if verified[i] { + continue + } + if byShard[ex.Shard] == nil { + byShard[ex.Shard] = map[string][]RunEntry{} + } + byShard[ex.Shard][ex.BusinessID] = append(byShard[ex.Shard][ex.BusinessID], RunEntry{ + RunID: ex.RunID, + ArchetypeID: ex.ArchetypeID, + }) + } + if len(byShard) == 0 { + return nil + } + shardIDs := make([]int32, 0, len(byShard)) + for sh := range byShard { + shardIDs = append(shardIDs, sh) + } + slices.Sort(shardIDs) + out := make([]ResumeShard, 0, len(shardIDs)) + for _, sh := range shardIDs { + out = append(out, ResumeShard{ + Shard: sh, + Execs: byShard[sh], + NoProgressDuration: now.Sub(shards[sh].lastProgress), + }) + } + return out +} + +// firstUnverifiedOnShard returns the index of the first execution in the +// flattened execs slice that targets the given shard and hasn't verified +// yet, and a found flag. Callers should only invoke this for shards +// with at least one pending exec; the found=false return is a defensive +// fallback so a tracker / verified-slice drift can't crash the activity. +func firstUnverifiedOnShard(execs []*shardedExecutionInfo, verified []bool, shard int32) (int, bool) { + for i, ex := range execs { + if verified[i] { + continue + } + if ex.Shard == shard { + return i, true + } + } + return 0, false +} + +// batchVerifyPartialErrorType is the ApplicationError Type stamped on +// wrappers produced by wrapBatchVerifyError. The workflow keys off +// this Type via extractVerifiedCountFromError to disambiguate "the +// wrapper we made" from any other ApplicationError carrying an +// int64. The original error is reachable via Unwrap on the wrapper. 
+const batchVerifyPartialErrorType = "BatchVerifyPartial" + +// wrapBatchVerifyError wraps the verify-phase error so the partial +// VerifiedCount survives the activity boundary — the SDK discards +// the activity result on failure, so the count would otherwise be +// lost. The original error is attached as Cause; the workflow side +// reaches it via errors.As / Unwrap as usual. Returns the cause +// unchanged when there's no progress to report. +func wrapBatchVerifyError(cause error, verifiedCount int64) error { + if cause == nil || verifiedCount <= 0 { + return cause + } + nonRetryable := false + if appErr, ok := errors.AsType[*temporal.ApplicationError](cause); ok { + nonRetryable = appErr.NonRetryable() + } + return temporal.NewApplicationErrorWithOptions( + cause.Error(), + batchVerifyPartialErrorType, + temporal.ApplicationErrorOptions{ + Cause: cause, + Details: []any{verifiedCount}, + NonRetryable: nonRetryable, + }, + ) +} + +// backoffDelay returns the per-exec retry delay after `attempt` +// consecutive failed verify attempts: 100ms × 2^(attempt-1), capped at +// 5s. The cap bounds how long after the apply pipeline recovers we'd +// take to notice; the per-shard no-progress timer fires after enough of +// these capped retries to distinguish "actively checking" from "lazy +// polling gave up". 
func backoffDelay(attempt int) time.Duration {
	// Attempts below 1 are treated as the first attempt.
	if attempt < 1 {
		attempt = 1
	}
	// Attempts 1..6 yield 100ms, 200ms, ..., 3.2s; from attempt 7 on the
	// doubled value would exceed the cap, so return the 5s cap directly.
	if attempt > 6 {
		return 5 * time.Second
	}
	return 100 * time.Millisecond * (1 << (attempt - 1))
}
diff --git a/service/worker/migration/sharded_activity_test.go b/service/worker/migration/sharded_activity_test.go
new file mode 100644
index 00000000000..05110158755
--- /dev/null
+++ b/service/worker/migration/sharded_activity_test.go
@@ -0,0 +1,346 @@
package migration

import (
	"time"

	commonpb "go.temporal.io/api/common/v1"
	enumspb "go.temporal.io/api/enums/v1"
	"go.temporal.io/api/serviceerror"
	"go.temporal.io/sdk/temporal"
	"go.temporal.io/server/api/adminservice/v1"
	enumsspb "go.temporal.io/server/api/enums/v1"
	"go.temporal.io/server/api/historyservice/v1"
	persistencespb "go.temporal.io/server/api/persistence/v1"
	"go.temporal.io/server/chasm"
	"go.temporal.io/server/common/namespace"
	"go.temporal.io/server/common/testing/protomock"
	"go.uber.org/mock/gomock"
	"google.golang.org/protobuf/types/known/durationpb"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// Sharded-activity tests reuse activitiesSuite's SetupTest so they get the
// same mock graph (HistoryClient, AdminClient, ChasmRegistry, etc.) as the
// legacy force-replication activity tests. ReplicateBatch resolves the
// remote admin client via clientBean.GetRemoteAdminClient, which the suite
// already arms with mockRemoteAdminClient — no per-test setup needed.

// payloadFor wraps a single ExecutionInfo into a BatchPayload on the named
// shard. Tests that need multiple execs across shards build the BatchPayload
// inline.
func payloadFor(shard int32, ex *ExecutionInfo) BatchPayload {
	// shard -> business ID -> one run entry for ex.
	return BatchPayload{
		shard: {ex.BusinessID: {{RunID: ex.RunID, ArchetypeID: ex.ArchetypeID}}},
	}
}

// newShardedReq builds a shardedBatchReq with sensible defaults for unit
// tests.
// ShardNoProgress is large so the stuck-shard backstop doesn't trip
// from real wall-clock latency; IdleShardCost is large so maybeSignalRelease
// (which needs the sdkClientFactory, nil here) doesn't fire.
// Tests override individual fields (Resume, ShardNoProgress, ...) on the
// returned request as needed.
func newShardedReq(execs BatchPayload) *shardedBatchReq {
	return &shardedBatchReq{
		BatchID:             1,
		Namespace:           mockedNamespace,
		NamespaceID:         mockedNamespaceID,
		Executions:          execs,
		TargetClusterName:   remoteCluster,
		PerBatchGenerateRPS: defaultPerBatchGenerateRPS,
		ShardNoProgress:     time.Hour,
		DrainGrace:          time.Second,
		IdleShardCost:       time.Hour,
	}
}

// expectRemoteNotFound primes the remote admin client to return NotFound
// for the given exec — the trigger for the verify-skip code path that
// consults source DMS to decide between zombie/retention skip vs. real
// pending state. The expectation matches the exact request and must be
// hit exactly once.
func (s *activitiesSuite) expectRemoteNotFound(ex *ExecutionInfo) {
	s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{
		Namespace: mockedNamespace,
		Execution: &commonpb.WorkflowExecution{
			WorkflowId: ex.BusinessID,
			RunId:      ex.RunID,
		},
		Archetype:       chasm.WorkflowArchetype,
		ArchetypeId:     ex.ArchetypeID,
		SkipForceReload: true,
	})).Return(nil, serviceerror.NewNotFound("")).Times(1)
}

// expectSourceDMS primes the mocked history client to answer exactly one
// DescribeMutableState call for ex with the given response and error. The
// expectation matches the exact request (namespace ID, execution,
// archetype ID, SkipForceReload=true).
func (s *activitiesSuite) expectSourceDMS(ex *ExecutionInfo, resp *historyservice.DescribeMutableStateResponse, err error) {
	s.mockHistoryClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&historyservice.DescribeMutableStateRequest{
		NamespaceId: mockedNamespaceID,
		Execution: &commonpb.WorkflowExecution{
			WorkflowId: ex.BusinessID,
			RunId:      ex.RunID,
		},
		ArchetypeId:     ex.ArchetypeID,
		SkipForceReload: true,
	})).Return(resp, err).Times(1)
}

// TestReplicateBatch_Success exercises the full inject+verify happy path
// for one exec: the inject phase calls GenerateLastHistoryReplicationTasks
// against HistoryClient, then the verify phase's DescribeMutableState on
// the
// remote admin client returns OK so workflowVerifier marks it
// verified. Mirrors the legacy TestVerifyReplicationTasks_Success and
// TestGenerateReplicationTasks_Success.
func (s *activitiesSuite) TestReplicateBatch_Success() {
	env, _ := s.initEnv()
	s.mockNamespaceRegistry.EXPECT().GetNamespaceByID(namespace.ID(mockedNamespaceID)).
		Return(&testNamespace, nil).Times(1)

	// Inject phase calls HistoryClient (generateMigrationTaskViaFrontend=false).
	s.mockHistoryClient.EXPECT().GenerateLastHistoryReplicationTasks(gomock.Any(), protomock.Eq(&historyservice.GenerateLastHistoryReplicationTasksRequest{
		NamespaceId: mockedNamespaceID,
		Execution: &commonpb.WorkflowExecution{
			WorkflowId: execution1.BusinessID,
			RunId:      execution1.RunID,
		},
		ArchetypeId:    execution1.ArchetypeID,
		TargetClusters: []string{remoteCluster},
	})).Return(&historyservice.GenerateLastHistoryReplicationTasksResponse{}, nil).Times(1)

	// Verify phase: remote DMS returns OK; workflowVerifierProvider
	// returns verified=true unconditionally so the exec verifies on the
	// first pass.
	s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{
		Namespace: mockedNamespace,
		Execution: &commonpb.WorkflowExecution{
			WorkflowId: execution1.BusinessID,
			RunId:      execution1.RunID,
		},
		Archetype:       chasm.WorkflowArchetype,
		ArchetypeId:     execution1.ArchetypeID,
		SkipForceReload: true,
	})).Return(&adminservice.DescribeMutableStateResponse{}, nil).Times(1)

	req := newShardedReq(payloadFor(0, execution1))
	f, err := env.ExecuteActivity(s.a.ReplicateBatch, req)
	s.NoError(err)
	var out replicateBatchResult
	s.NoError(f.Get(&out))
	// One exec on shard 0: fully verified, shard reported complete,
	// nothing left in flight.
	s.Equal(int64(1), out.VerifiedCount)
	s.Equal([]int32{0}, out.CompletedShards)
	s.Empty(out.InFlight)
}

// TestReplicateBatch_SkipZombie exercises the retention/zombie skip path:
// remote DMS returns NotFound, source DMS returns a zombie state, and
// checkSkipWorkflowExecution marks the exec as verified-via-skip so the
// shard's verify accounting completes. Mirrors the existing
// TestVerifyReplicationTasks_SkipWorkflowExecution.
func (s *activitiesSuite) TestReplicateBatch_SkipZombie() {
	env, _ := s.initEnv()
	s.mockNamespaceRegistry.EXPECT().GetNamespaceByID(namespace.ID(mockedNamespaceID)).
		Return(&testNamespace, nil).Times(1)

	// Resume=true so we skip inject and only exercise the verify-skip path.
	req := newShardedReq(payloadFor(0, execution1))
	req.Resume = true

	s.expectRemoteNotFound(execution1)
	s.expectSourceDMS(execution1, zombieState, nil)

	f, err := env.ExecuteActivity(s.a.ReplicateBatch, req)
	s.NoError(err)
	var out replicateBatchResult
	s.NoError(f.Get(&out))
	// The skipped exec is counted as verified.
	s.Equal(int64(1), out.VerifiedCount)
	s.Empty(out.InFlight)
}

// TestReplicateBatch_SkipRetention exercises the close-time/retention skip
// path: remote DMS returns NotFound, source DMS returns a completed
// workflow whose CloseTime+Retention is in the past, so
// checkSkipWorkflowExecution marks it skipped (counted as verified).
// Mirrors the existing Test_verifyReplicationTasksSkipRetention.
func (s *activitiesSuite) TestReplicateBatch_SkipRetention() {
	env, _ := s.initEnv()

	// Retention is 1h and the workflow closed 2h ago, so close time plus
	// retention lies one hour in the past.
	retention := time.Hour
	closeTime := time.Now().Add(-2 * retention) // deleteTime is in the past

	// Build a real namespace.Namespace with a retention setting so
	// checkSkipWorkflowExecution's `ns.Retention()` returns non-zero.
	factory := namespace.NewDefaultReplicationResolverFactory()
	detail := &persistencespb.NamespaceDetail{
		Info: &persistencespb.NamespaceInfo{},
		Config: &persistencespb.NamespaceConfig{
			Retention: durationpb.New(retention),
		},
		ReplicationConfig: &persistencespb.NamespaceReplicationConfig{},
	}
	ns, nsErr := namespace.FromPersistentState(detail, factory(detail))
	s.NoError(nsErr)
	// Override the suite-default GetNamespaceByID for this test so the
	// activity sees a namespace with retention configured.
	s.mockNamespaceRegistry.EXPECT().GetNamespaceByID(namespace.ID(mockedNamespaceID)).
		Return(ns, nil).Times(1)

	s.expectRemoteNotFound(execution1)
	// Source reports a completed workflow with the stale close time.
	s.expectSourceDMS(execution1, &historyservice.DescribeMutableStateResponse{
		DatabaseMutableState: &persistencespb.WorkflowMutableState{
			ExecutionState: &persistencespb.WorkflowExecutionState{
				State: enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED,
			},
			ExecutionInfo: &persistencespb.WorkflowExecutionInfo{
				CloseTime: timestamppb.New(closeTime),
			},
		},
	}, nil)

	// Resume=true skips the inject phase, matching the zombie-skip test.
	req := newShardedReq(payloadFor(0, execution1))
	req.Resume = true
	f, err := env.ExecuteActivity(s.a.ReplicateBatch, req)
	s.NoError(err)
	var out replicateBatchResult
	s.NoError(f.Get(&out))
	s.Equal(int64(1), out.VerifiedCount)
}

// TestReplicateBatch_ShardNoProgress: the per-shard cumulative no-progress
// backstop fires non-retryably when a shard has gone longer than
// req.ShardNoProgress without a verified outcome.
// Resume=true with a
// pre-seeded NoProgressByShard pushes the shard past the threshold
// (1s seeded against a 10ms limit) before the first verify pass, so the
// first failed attempt trips the check immediately and we don't have to
// spin on wall-clock. Mirrors the existing
// TestVerifyReplicationTasks_FailedNotFound.
func (s *activitiesSuite) TestReplicateBatch_ShardNoProgress() {
	env, _ := s.initEnv()
	s.mockNamespaceRegistry.EXPECT().GetNamespaceByID(namespace.ID(mockedNamespaceID)).
		Return(&testNamespace, nil).Times(1)

	// Remote returns BUSY_WORKFLOW so verify counts the exec as
	// pending without consulting source — that keeps the shard's
	// lastProgress at its seeded (already-stale) value.
	s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), gomock.Any()).
		Return(nil, &serviceerror.ResourceExhausted{
			Cause: enumspb.RESOURCE_EXHAUSTED_CAUSE_BUSY_WORKFLOW,
		}).AnyTimes()

	req := newShardedReq(payloadFor(0, execution1))
	req.Resume = true                                 // skip inject
	req.ShardNoProgress = 10 * time.Millisecond       // trip almost immediately
	req.NoProgressByShard = map[int32]time.Duration{ // seed past threshold
		0: time.Second,
	}

	_, err := env.ExecuteActivity(s.a.ReplicateBatch, req)
	s.Error(err)
	var appErr *temporal.ApplicationError
	s.ErrorAs(err, &appErr)
	s.Equal("ShardNoProgress", appErr.Type())
	s.True(appErr.NonRetryable(), "ShardNoProgress should be non-retryable")
}

// TestReplicateBatch_Resume_SkipsInject: Resume=true should bypass the
// inject phase entirely — no GenerateLastHistoryReplicationTasks call.
// Mirrors the inject-side guarantee that the legacy
// TestVerifyReplicationTasks_AlreadyVerified asserts for verify
// (resume-via-heartbeat skips already-done work).
func (s *activitiesSuite) TestReplicateBatch_Resume_SkipsInject() {
	env, _ := s.initEnv()
	s.mockNamespaceRegistry.EXPECT().GetNamespaceByID(namespace.ID(mockedNamespaceID)).
		Return(&testNamespace, nil).Times(1)

	// No GenerateLastHistoryReplicationTasks expectation — gomock with
	// strict expectations would fail if inject called it. Verify path:
	// remote DMS OK so the exec verifies on first pass.
	s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), gomock.Any()).
		Return(&adminservice.DescribeMutableStateResponse{}, nil).Times(1)

	req := newShardedReq(payloadFor(0, execution1))
	req.Resume = true
	f, err := env.ExecuteActivity(s.a.ReplicateBatch, req)
	s.NoError(err)
	var out replicateBatchResult
	s.NoError(f.Get(&out))
	s.Equal(int64(1), out.VerifiedCount)
}

// TestReplicateBatch_DisableVerification: with verification disabled the
// activity runs inject, then returns immediately with VerifiedCount=0 and
// every batch shard listed as completed — no DMS calls. Mirrors the
// workflow-level TestSharded_DisableVerification_NoVerifiedCount but
// from the activity side.
func (s *activitiesSuite) TestReplicateBatch_DisableVerification() {
	env, _ := s.initEnv()
	// No NewRemoteAdminClientWithTimeout expectation — the activity
	// builds the client unconditionally even in inject-only mode, so we
	// still need the factory to hand back something. Return without
	// expecting any DMS calls on it.
	//
	// NOTE(review): the comment above contradicts itself (no expectation
	// vs. the factory must return something) and names
	// NewRemoteAdminClientWithTimeout, while the file-level comment says
	// the remote client comes from clientBean.GetRemoteAdminClient armed
	// by the suite. Unlike the sibling tests, this one also sets no
	// GetNamespaceByID expectation. Confirm which description is current
	// and reword.

	s.mockHistoryClient.EXPECT().GenerateLastHistoryReplicationTasks(gomock.Any(), gomock.Any()).
		Return(&historyservice.GenerateLastHistoryReplicationTasksResponse{}, nil).Times(1)

	req := newShardedReq(payloadFor(0, execution1))
	req.DisableVerification = true
	f, err := env.ExecuteActivity(s.a.ReplicateBatch, req)
	s.NoError(err)
	var out replicateBatchResult
	s.NoError(f.Get(&out))
	s.Equal(int64(0), out.VerifiedCount)
	s.Equal([]int32{0}, out.CompletedShards)
}

// TestReplicateBatch_HeartbeatResumesInject: a recorded NextInjectIdx
// heartbeat from a prior attempt causes the inject phase to skip
// already-injected execs.
// Pre-seeds heartbeat NextInjectIdx=1 across a
// two-exec batch, then asserts only the second exec's
// GenerateLastHistoryReplicationTasks is invoked. Mirrors the legacy
// TestGenerateReplicationTasks_Success_ViaFrontend's heartbeat-resume
// assertion.
func (s *activitiesSuite) TestReplicateBatch_HeartbeatResumesInject() {
	env, _ := s.initEnv()
	s.mockNamespaceRegistry.EXPECT().GetNamespaceByID(namespace.ID(mockedNamespaceID)).
		Return(&testNamespace, nil).Times(1)

	// Pre-seed the heartbeat so inject resumes at index 1.
	env.SetHeartbeatDetails(replicateBatchHeartbeat{NextInjectIdx: 1, InjectDone: false})

	// Two execs on shard 0. flatten() orders by BID alphabetically, so
	// index 0 = execution1 ("workflow1"), index 1 = execution2 ("workflow2").
	payload := BatchPayload{
		0: {
			execution1.BusinessID: {{RunID: execution1.RunID, ArchetypeID: execution1.ArchetypeID}},
			execution2.BusinessID: {{RunID: execution2.RunID, ArchetypeID: execution2.ArchetypeID}},
		},
	}

	// Only execution2 (index 1) should be injected. Strict Times(1)
	// ensures execution1 is NOT injected — gomock would fail an
	// unexpected execution1 call.
	s.mockHistoryClient.EXPECT().GenerateLastHistoryReplicationTasks(gomock.Any(), protomock.Eq(&historyservice.GenerateLastHistoryReplicationTasksRequest{
		NamespaceId: mockedNamespaceID,
		Execution: &commonpb.WorkflowExecution{
			WorkflowId: execution2.BusinessID,
			RunId:      execution2.RunID,
		},
		ArchetypeId:    execution2.ArchetypeID,
		TargetClusters: []string{remoteCluster},
	})).Return(&historyservice.GenerateLastHistoryReplicationTasksResponse{}, nil).Times(1)

	// Both execs verify in one pass.
	for _, ex := range []*ExecutionInfo{execution1, execution2} {
		s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{
			Namespace: mockedNamespace,
			Execution: &commonpb.WorkflowExecution{
				WorkflowId: ex.BusinessID,
				RunId:      ex.RunID,
			},
			Archetype:       chasm.WorkflowArchetype,
			ArchetypeId:     ex.ArchetypeID,
			SkipForceReload: true,
		})).Return(&adminservice.DescribeMutableStateResponse{}, nil).Times(1)
	}

	req := newShardedReq(payload)
	f, err := env.ExecuteActivity(s.a.ReplicateBatch, req)
	s.NoError(err)
	var out replicateBatchResult
	s.NoError(f.Get(&out))
	// Verification still covers both execs even though only one was
	// injected by this attempt.
	s.Equal(int64(2), out.VerifiedCount)
}
diff --git a/service/worker/migration/sharded_types.go b/service/worker/migration/sharded_types.go
new file mode 100644
index 00000000000..1de0ba06ce8
--- /dev/null
+++ b/service/worker/migration/sharded_types.go
@@ -0,0 +1,363 @@
package migration

import (
	"encoding/json"
	"fmt"
	"slices"
	"time"
)

const (
	// shardedForceReplicationWorkflowName is the registered workflow name.
	// Distinct from the legacy ForceReplicationWorkflow so both variants
	// coexist in the same worker — pick which to use at workflow start
	// time by setting the start request's workflow type.
	shardedForceReplicationWorkflowName = "force-replication-sharded"

	// releaseShardsSignalName carries mid-flight ReleaseShards signals
	// from active replicate-batch activities back to their parent
	// workflow. Drain-mode shard completions ride the activity return
	// value instead, so this signal only fires while the activity is
	// still running normally.
	releaseShardsSignalName = "ReleaseShards"

	// defaultShardedListPageSize is the ListWorkflows page size when
	// the sharded workflow's params.ListWorkflowsPageSize is unset.
	defaultShardedListPageSize = 1000

	// defaultBatchSize bounds the total executions in any single
	// ReplicateBatch activity. Activity-payload sizing knob, not a
	// shard-fill threshold.
	defaultBatchSize = 100

	// defaultMaxExecsPerShard bounds the executions any single shard
	// can contribute to a batch — i.e. the per-shard inject blast
	// radius before that shard's apply queue has to absorb a burst.
	// 50 caps a hot shard's contribution at half a default-sized
	// batch (BatchSize=100), so a full batch still spans ≥2 shards.
	defaultMaxExecsPerShard = 50

	// defaultShardNoProgress is the per-shard cumulative no-progress
	// backstop. While a shard's pending exec count is non-zero and
	// no exec on that shard has produced a verified outcome for this
	// long (carried across CAN via the resume payload), the activity
	// fails non-retryably naming the stuck shard.
	defaultShardNoProgress = 5 * time.Minute

	// defaultDrainGrace is the wall-budget the activity gets after
	// the workflow cancels it for CAN. Continues verifying until
	// either the grace expires, the idle-cost trigger fires, or
	// every exec verifies.
	defaultDrainGrace = 15 * time.Second

	// defaultIdleShardCost is the cumulative idle-time threshold
	// (the "shard-seconds" unit: 30 s with 1 idle shard equals
	// 3.3 s with 9 idle) at which the activity signal-releases its
	// completed-but-not-yet-released shards mid-flight.
	defaultIdleShardCost = 30 * time.Second

	// defaultPerBatchGenerateRPS is the per-batch inject-phase target.
	// Sharded dispatches many concurrent batches and each builds its
	// own limiter, so this caps the per-batch generate-replication-task
	// rate; the workflow does not normalise against a global cap the
	// way the existing migration's OverallRps does.
	defaultPerBatchGenerateRPS = 30.0

	// defaultConcurrentBatchCap is the ceiling applied to the derived
	// default of targetShardCount/4. Keeps the in-flight batch count
	// safely inside per-worker concurrent-activity budgets and bounds
	// the cluster blast radius of a single force-rep run.
+ defaultConcurrentBatchCap = 500 +) + +// RunEntry is the per-run leaf in the nested batch payload. Carries the +// RunID plus an optional ArchetypeID, serialised as a JSON tuple: +// `["runID"]` when ArchetypeID is zero, `["runID", N]` when set. +// +// Why a tuple, not an object: a tuple omits the JSON field names +// (`"r":`, `"a":`) that would otherwise repeat on every run, which is +// the main lever behind the nested payload's byte savings. With many +// runs per BusinessID (the heavy-reuse case), this collapses the +// per-run encoding overhead from ~47 bytes (flat ExecutionInfo) to +// ~10 bytes per run. +type RunEntry struct { + RunID string + ArchetypeID uint32 +} + +// MarshalJSON serialises RunEntry as a heterogeneous JSON tuple. Custom +// because Go's default JSON can't express a heterogeneous tuple, and +// archetype-omission needs to happen by changing the tuple length +// rather than emitting an explicit zero. UnmarshalJSON below is the +// inverse. +func (r RunEntry) MarshalJSON() ([]byte, error) { + if r.ArchetypeID == 0 { + return fmt.Appendf(nil, `[%q]`, r.RunID), nil + } + return fmt.Appendf(nil, `[%q,%d]`, r.RunID, r.ArchetypeID), nil +} + +func (r *RunEntry) UnmarshalJSON(data []byte) error { + var raw []json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return fmt.Errorf("RunEntry: %w", err) + } + if len(raw) < 1 || len(raw) > 2 { + return fmt.Errorf("RunEntry: expected [runID] or [runID,archetypeID], got %d-element array", len(raw)) + } + if err := json.Unmarshal(raw[0], &r.RunID); err != nil { + return fmt.Errorf("RunEntry runID: %w", err) + } + r.ArchetypeID = 0 + if len(raw) == 2 { + if err := json.Unmarshal(raw[1], &r.ArchetypeID); err != nil { + return fmt.Errorf("RunEntry archetypeID: %w", err) + } + } + return nil +} + +// BatchPayload groups runs by (shard, businessID) so a single businessID +// with many runs costs one BID-string-worth of bytes instead of one per +// run. 
+// The wire shape behind shardedBatchReq.Executions, ResumeShard.Execs
+// (the per-shard inner map), and ShardedForceReplicationParams.RecoveredBuckets.
+//
+// On-wire form:
+//
+//	{"shardID": {"businessID": [["runID"], ["runID", archetypeID], ...], ...}, ...}
+//
+// Top-level keys are shard IDs; inner-map keys are businessIDs; values
+// are RunEntry tuples.
+type BatchPayload map[int32]map[string][]RunEntry
+
+// totalRuns returns how many RunEntry values the payload holds, summed
+// over every shard and every businessID.
+func (p BatchPayload) totalRuns() int {
+	var total int
+	for shard := range p {
+		for bid := range p[shard] {
+			total += len(p[shard][bid])
+		}
+	}
+	return total
+}
+
+// sortedShards returns the payload's shard IDs in ascending order, so
+// callers such as flatten get a stable iteration order instead of Go's
+// randomised map order.
+func (p BatchPayload) sortedShards() []int32 {
+	shards := make([]int32, len(p))
+	next := 0
+	for shard := range p {
+		shards[next] = shard
+		next++
+	}
+	slices.Sort(shards)
+	return shards
+}
+
+// shardedExecutionInfo pairs an upstream ExecutionInfo with the destination
+// history shard the sharded design routes by. Shard is kept out of the
+// upstream ExecutionInfo struct because no other workflow needs it; the
+// wire format (BatchPayload) carries shard as the outer map key.
+type shardedExecutionInfo struct {
+	*ExecutionInfo
+	Shard int32
+}
+
+// flatten produces a deterministically-ordered slice of execs paired with
+// their destination shard: shards ascending, BIDs alphabetical within
+// shard, runs in input order. Lets the inner verify loop stay index-based
+// even though the wire shape is nested.
+func (p BatchPayload) flatten() []*shardedExecutionInfo {
+	total := p.totalRuns()
+	if total == 0 {
+		return nil
+	}
+	flat := make([]*shardedExecutionInfo, 0, total)
+	for _, shard := range p.sortedShards() {
+		runsByBID := p[shard]
+		// Map iteration order is random; sort the businessIDs so the
+		// flattened index of every exec is stable.
+		sortedBIDs := make([]string, 0, len(runsByBID))
+		for bid := range runsByBID {
+			sortedBIDs = append(sortedBIDs, bid)
+		}
+		slices.Sort(sortedBIDs)
+		for _, bid := range sortedBIDs {
+			for _, run := range runsByBID[bid] {
+				info := &ExecutionInfo{
+					BusinessID:  bid,
+					RunID:       run.RunID,
+					ArchetypeID: run.ArchetypeID,
+				}
+				flat = append(flat, &shardedExecutionInfo{ExecutionInfo: info, Shard: shard})
+			}
+		}
+	}
+	return flat
+}
+
+// merge appends every run in src onto the matching (shard, businessID)
+// list in p, creating the per-shard map on first use. The receiver must
+// be non-nil because it is written to.
+func (p BatchPayload) merge(src BatchPayload) {
+	for shard, srcByBID := range src {
+		dst := p[shard]
+		if dst == nil {
+			dst = map[string][]RunEntry{}
+			p[shard] = dst
+		}
+		for bid, runs := range srcByBID {
+			dst[bid] = append(dst[bid], runs...)
+		}
+	}
+}
+
+// ShardedForceReplicationParams is the workflow input. Configuration
+// fields are read-only across CAN cycles; the carry-over block at the
+// bottom is mutated each cycle.
+type ShardedForceReplicationParams struct {
+	// ---- Configuration ----
+	Namespace             string
+	Query                 string
+	BatchSize             int
+	MaxExecsPerShard      int
+	ListWorkflowsPageSize int
+	TargetClusterName     string
+	DisableVerification   bool
+
+	ShardNoProgress time.Duration
+	DrainGrace      time.Duration
+	IdleShardCost   time.Duration
+
+	TaskQueueUserDataReplicationParams TaskQueueUserDataReplicationParams
+
+	// PerBatchGenerateRPS is the inject-phase rate-limiter target inside
+	// each ReplicateBatch activity. See defaultPerBatchGenerateRPS for
+	// the rationale; defaults to that value.
+	PerBatchGenerateRPS float64
+
+	// ConcurrentBatchCount is the absolute ceiling on in-flight
+	// ReplicateBatch activities. Per-shard exclusivity already bounds
+	// concurrency to the target shard count, but at large cluster sizes
+	// that's well past the worker's concurrent-activity budget. This
+	// cap keeps the workflow inside that budget and limits the cluster
+	// blast radius of a single force-rep run. Defaults to
+	// min(targetShardCount/4, defaultConcurrentBatchCap).
+	ConcurrentBatchCount int
+
+	// EstimationMultiplier sizes the QPSQueue's initial slice capacity
+	// (multiplier × ConcurrentBatchCount + 1). Pure allocation hint —
+	// the sliding window's logical max stays ConcurrentBatchCount + 1.
+	// Defaults to 2.
+	EstimationMultiplier int
+
+	// ---- Continue-as-new carry-over ----
+	NextPageToken                    []byte
+	ContinuedAsNewCount              int
+	TotalForceReplicateWorkflowCount int64
+	ReplicatedWorkflowCount          int64
+	ReplicatedWorkflowCountPerSecond float64
+
+	// QPSQueue carries the sliding-window samples across CAN so the
+	// per-second rate doesn't drop to zero on every cycle boundary.
+	QPSQueue QPSQueue
+
+	// ResumeShards carries unverified execs from drained activities in
+	// the prior CAN cycle. The new run dispatches resume activities for
+	// these before the page loop runs, so their shards are claimed in
+	// shardInFlight from the start and the packer treats them as busy.
+	ResumeShards []ResumeShard
+
+	// RecoveredBuckets carries execs whose dispatching activity returned
+	// a cancellation without returning a result — i.e., the activity
+	// body never ran (cancel-before-start race). They were dispatched
+	// but never injected, so the new cycle restores them into the
+	// streaming buckets to be dispatched as fresh inject+verify batches.
+	RecoveredBuckets BatchPayload
+
+	TaskQueueUserDataReplicationStatus TaskQueueUserDataReplicationStatus
+}
+
+// ResumeShard carries one shard's worth of unverified execs from a drained
+// activity across a CAN boundary to the resume activity that picks them
+// up.
+// NoProgressDuration is the cumulative time the shard went without a
+// verified outcome at drain time; the resume activity initialises its own
+// per-shard last-progress clock to (now - NoProgressDuration) so the
+// backstop check sees the full elapsed no-progress window, not just the
+// current activity's slice.
+//
+// Execs is keyed by businessID: each entry is a list of RunEntry tuples
+// for that BID. Grouping by BID at the wire level lets a hot BID (with
+// many runs) collapse to one BID-string + N tuples rather than N copies
+// of the BID; see BatchPayload's docstring.
+type ResumeShard struct {
+	Shard              int32
+	Execs              map[string][]RunEntry
+	NoProgressDuration time.Duration
+}
+
+// shardedBatchReq is the per-batch activity input. Executions is the
+// per-shard, per-BID nested payload. Before dispatch the workflow has
+// marked, in shardInFlight, every shard that appears as a top-level key
+// of Executions; the activity is responsible for either signal-releasing
+// each of those shards mid-flight or listing it in the return value's
+// CompletedShards / InFlight set.
+//
+// Resume=true skips the inject phase: the execs were already injected by
+// some earlier activity that was cancelled at drain time and returned its
+// unverified execs in its result. NoProgressByShard carries the cumulative
+// pre-resume no-progress duration so the per-shard backstop stays
+// meaningful across resume cycles.
+type shardedBatchReq struct {
+	BatchID     int64
+	Namespace   string
+	NamespaceID string
+	Executions  BatchPayload
+
+	TargetClusterName string
+
+	Resume              bool
+	DisableVerification bool
+	// Presumably keyed by shard ID, matching the top-level keys of
+	// Executions — confirm against the dispatch code.
+	NoProgressByShard map[int32]time.Duration
+
+	// The tuning knobs below share their names with fields of
+	// ShardedForceReplicationParams; NOTE(review): presumably copied from
+	// there at dispatch time — confirm against the caller.
+	PerBatchGenerateRPS float64
+
+	ShardNoProgress time.Duration
+	DrainGrace      time.Duration
+	IdleShardCost   time.Duration
+}
+
+// replicateBatchResult is the activity's return payload. The activity is
+// the source of truth for which execs verified vs.
+// are still outstanding
+// when it returns — only it has the per-exec verify state — so the drain
+// payload rides the return value rather than a signal. The workflow's
+// dispatch coroutine reads InFlight into drainPayload on nil-error return.
+//
+// CompletedShards is informational (the dispatch coroutine's defer clears
+// heldByBatch + shardInFlight regardless), but keeping it in the result
+// gives metrics a clean handle on "which shards this batch finished".
+type replicateBatchResult struct {
+	CompletedShards []int32
+	// InFlight holds one ResumeShard per shard that still has unverified
+	// execs when the activity returns.
+	InFlight []ResumeShard
+
+	// VerifiedCount is the number of executions this activity invocation
+	// finished verifying (including retention/zombie skips that resolve
+	// as verified). The workflow accumulates this into its running
+	// ReplicatedWorkflowCount and emits the per-batch delta as the
+	// replicated_workflow_count counter.
+	VerifiedCount int64
+}
+
+// replicateBatchHeartbeat is the heartbeat-details payload recording how
+// far the inject phase has progressed, so a retried attempt can resume
+// instead of re-injecting from the start.
+type replicateBatchHeartbeat struct {
+	// NextInjectIdx is the index of the next exec to inject on retry.
+	NextInjectIdx int
+	// InjectDone marks the inject phase as complete; retries skip inject.
+	InjectDone bool
+}
+
+// releaseShardsPayload is the body of the mid-flight ReleaseShards signal
+// an activity sends to its parent workflow when the cumulative idle cost
+// across its completed-but-not-yet-released shards crosses IdleShardCost.
+// The workflow handler clears these shards from shardInFlight +
+// heldByBatch[BatchID] so the packer can immediately dispatch new work
+// against them while the activity stays running on its still-pending
+// shards. Only fires in normal mode — once the activity enters drain mode
+// it returns its remaining state via the activity result instead.
+type releaseShardsPayload struct {
+	BatchID int64
+	Shards  []int32
+}
diff --git a/service/worker/migration/sharded_types_test.go b/service/worker/migration/sharded_types_test.go
new file mode 100644
index 00000000000..05c85b8c0b4
--- /dev/null
+++ b/service/worker/migration/sharded_types_test.go
@@ -0,0 +1,91 @@
+package migration
+
+import (
+	"encoding/json"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+// TestRunEntry_JSONRoundTrip fixes the tuple wire shape of RunEntry: a
+// one-element array when ArchetypeID is zero, a two-element array
+// otherwise, and checks that decoding the bytes restores the entry.
+func TestRunEntry_JSONRoundTrip(t *testing.T) {
+	type roundTripCase struct {
+		name   string
+		entry  RunEntry
+		expect string
+	}
+	cases := []roundTripCase{
+		{name: "zero archetype is omitted", entry: RunEntry{RunID: "r1"}, expect: `["r1"]`},
+		{name: "non-zero archetype is included", entry: RunEntry{RunID: "r1", ArchetypeID: 42}, expect: `["r1",42]`},
+		{name: "escaped runID", entry: RunEntry{RunID: `r"1`}, expect: `["r\"1"]`},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			encoded, err := json.Marshal(tc.entry)
+			require.NoError(t, err)
+			require.Equal(t, tc.expect, string(encoded))
+
+			var decoded RunEntry
+			require.NoError(t, json.Unmarshal(encoded, &decoded))
+			require.Equal(t, tc.entry, decoded)
+		})
+	}
+}
+
+// TestRunEntry_UnmarshalRejectsBadShape feeds UnmarshalJSON inputs the
+// marshaller never produces (empty array, three elements, object, bare
+// string) and requires an error for each rather than a silent zero
+// value.
+func TestRunEntry_UnmarshalRejectsBadShape(t *testing.T) {
+	badInputs := []string{
+		`[]`,
+		`["r1", 1, 2]`,
+		`{"r": "r1"}`,
+		`"r1"`,
+	}
+	for _, input := range badInputs {
+		t.Run(input, func(t *testing.T) {
+			var decoded RunEntry
+			err := json.Unmarshal([]byte(input), &decoded)
+			require.Error(t, err)
+		})
+	}
+}
+
+// TestBatchPayload_JSONRoundTrip exercises the nested wire shape end
+// to end so a change to either RunEntry or the surrounding map type
+// can't silently regress it.
+func TestBatchPayload_JSONRoundTrip(t *testing.T) {
+	payload := BatchPayload{
+		7: {"bid-a": {{RunID: "r1"}, {RunID: "r2", ArchetypeID: 5}}},
+		8: {"bid-b": {{RunID: "r3"}}},
+	}
+	encoded, err := json.Marshal(payload)
+	require.NoError(t, err)
+	// encoding/json writes map keys in sorted order, so the encoded
+	// bytes are stable from run to run.
+	require.JSONEq(t, `{
+		"7": {"bid-a": [["r1"], ["r2", 5]]},
+		"8": {"bid-b": [["r3"]]}
+	}`, string(encoded))
+
+	var decoded BatchPayload
+	require.NoError(t, json.Unmarshal(encoded, &decoded))
+	require.Equal(t, payload, decoded)
+}
+
+// TestBatchPayload_Flatten checks flatten's ordering contract: shards
+// ascending, then businessIDs alphabetical within a shard, with the
+// runs of one businessID kept in input order.
+func TestBatchPayload_Flatten(t *testing.T) {
+	payload := BatchPayload{
+		2: {"b-z": {{RunID: "rz"}}, "b-a": {{RunID: "ra1"}, {RunID: "ra2"}}},
+		1: {"b-c": {{RunID: "rc"}}},
+	}
+	flat := payload.flatten()
+	require.Len(t, flat, 4)
+
+	// Shard 1 sorts first.
+	require.Equal(t, int32(1), flat[0].Shard)
+	require.Equal(t, "b-c", flat[0].BusinessID)
+
+	// Shard 2: "b-a" (both runs, input order) precedes "b-z".
+	require.Equal(t, int32(2), flat[1].Shard)
+	require.Equal(t, "b-a", flat[1].BusinessID)
+	require.Equal(t, "ra1", flat[1].RunID)
+	require.Equal(t, "ra2", flat[2].RunID)
+	require.Equal(t, "b-z", flat[3].BusinessID)
+}
diff --git a/service/worker/migration/sharded_workflow.go b/service/worker/migration/sharded_workflow.go
new file mode 100644
index 00000000000..5e856878c7d
--- /dev/null
+++ b/service/worker/migration/sharded_workflow.go
@@ -0,0 +1,1173 @@
+package migration
+
+import (
+	"errors"
+	"fmt"
+	"slices"
+	"time"
+
+	enumspb "go.temporal.io/api/enums/v1"
+	"go.temporal.io/api/workflowservice/v1"
+	sdkclient "go.temporal.io/sdk/client"
+	"go.temporal.io/sdk/temporal"
+	"go.temporal.io/sdk/workflow"
+	"go.temporal.io/server/common"
+	"go.temporal.io/server/common/metrics"
+)
+
+// ShardedForceReplicationWorkflow runs the sharded design for one CAN
+// cycle: dispatch any resume activities carried over from the prior
+// cycle, then page through ListWorkflows until either the
+// namespace exhausts or workflow.GetInfo(ctx).GetContinueAsNewSuggested()
+// trips, bucketing each execution by destination history shard and
+// dispatching a paired inject+verify activity once a bucket reaches
+// packing eligibility. At cycle end, every remaining bucket flushes as
+// packed activities.
+//
+// If there are more pages, in-flight activities are cancelled (giving
+// them DrainGrace to drain), their drain payload arrives via the
+// activity return value, and the workflow CANs with NextPageToken +
+// ResumeShards in the carry-over. Otherwise it waits for activities
+// to finish naturally and returns nil.
+//
+// Returns a non-retryable "InvalidArgument" / "InvalidConfiguration"
+// application error for bad params, any error from the setup
+// activities or from state.run (including the continue-as-new error),
+// and a plain error when task queue user data replication reported a
+// failure message.
+func ShardedForceReplicationWorkflow(ctx workflow.Context, params ShardedForceReplicationParams) error {
+	// Page token at workflow entry — returned by the status query as
+	// PageTokenForRestart so tooling that already knows the legacy
+	// "restart from the starting position" semantic keeps working.
+	// The richer Recovery* fields below carry the current page token
+	// plus in-flight execs and are what the sharded restart flow
+	// actually uses.
+	startPageToken := params.NextPageToken
+
+	// state is assigned after newShardedWorkflowState below; the
+	// query handler closes over the pointer so it sees the live state
+	// once setup completes. Queries that arrive during setup return
+	// the static fields without the recovery bundle, which matches
+	// the prior behaviour.
+	var state *shardedWorkflowState
+
+	// Register the status query under the same name upstream uses
+	// (forceReplicationStatusQueryType = "force-replication-status")
+	// so tooling that polls force-rep progress works across both
+	// workflow variants.
+	if err := workflow.SetQueryHandler(ctx, forceReplicationStatusQueryType, func() (ForceReplicationStatus, error) {
+		status := ForceReplicationStatus{
+			ContinuedAsNewCount:                params.ContinuedAsNewCount,
+			TotalWorkflowCount:                 params.TotalForceReplicateWorkflowCount,
+			ReplicatedWorkflowCount:            params.ReplicatedWorkflowCount,
+			ReplicatedWorkflowCountPerSecond:   params.ReplicatedWorkflowCountPerSecond,
+			PageTokenForRestart:                startPageToken,
+			TaskQueueUserDataReplicationStatus: params.TaskQueueUserDataReplicationStatus,
+			RecoveryNextPageToken:              params.NextPageToken,
+			RecoveryResumeShards:               params.ResumeShards,
+			RecoveryBuckets:                    params.RecoveredBuckets,
+		}
+		if state != nil {
+			status.RecoveryResumeShards = state.collectResumeShardsForCarryover()
+			status.RecoveryBuckets = state.collectRecoveredBucketsForCarryover()
+		}
+		return status, nil
+	}); err != nil {
+		return err
+	}
+
+	if err := validateShardedForceReplicationParams(&params); err != nil {
+		return err
+	}
+
+	var err error
+	state, err = newShardedWorkflowState(ctx, &params)
+	if err != nil {
+		return err
+	}
+	// Defaults are now applied; reject configurations the packer
+	// can't honour. MaxExecsPerShard > BatchSize is meaningless —
+	// each batch caps at BatchSize total, so the per-shard cap
+	// can't exceed the whole-batch cap.
+	if params.MaxExecsPerShard > params.BatchSize {
+		return temporal.NewNonRetryableApplicationError(
+			fmt.Sprintf("MaxExecsPerShard (%d) must be <= BatchSize (%d)", params.MaxExecsPerShard, params.BatchSize),
+			"InvalidConfiguration", nil)
+	}
+
+	// On the first cycle, populate TotalForceReplicateWorkflowCount
+	// via the same CountWorkflow activity upstream uses. Skipped on
+	// subsequent CAN cycles — the count carries across via params.
+	if params.TotalForceReplicateWorkflowCount == 0 {
+		wfCount, err := shardedCountWorkflowsForReplication(ctx, &params)
+		if err != nil {
+			return err
+		}
+		params.TotalForceReplicateWorkflowCount = wfCount
+	}
+
+	if !params.TaskQueueUserDataReplicationStatus.Done {
+		if err := maybeKickoffShardedTaskQueueUserDataReplication(ctx, &params, func(failureReason string) {
+			params.TaskQueueUserDataReplicationStatus.FailureMessage = failureReason
+			params.TaskQueueUserDataReplicationStatus.Done = true
+		}); err != nil {
+			return err
+		}
+	}
+
+	if err := state.run(ctx); err != nil {
+		return err
+	}
+
+	// state.run returned nil only on the terminal cycle (no more pages,
+	// no errors). On CAN cycles it returns the CAN error, so we never
+	// reach here mid-replication.
+	if err := workflow.Await(ctx, func() bool { return params.TaskQueueUserDataReplicationStatus.Done }); err != nil {
+		return err
+	}
+	if params.TaskQueueUserDataReplicationStatus.FailureMessage != "" {
+		return fmt.Errorf("task queue user data replication failed: %v", params.TaskQueueUserDataReplicationStatus.FailureMessage)
+	}
+	return nil
+}
+
+// validateShardedForceReplicationParams rejects input that is missing
+// Namespace or TargetClusterName, returning a non-retryable
+// "InvalidArgument" application error. It does not modify params.
+func validateShardedForceReplicationParams(params *ShardedForceReplicationParams) error {
+	if len(params.Namespace) == 0 {
+		return temporal.NewNonRetryableApplicationError("InvalidArgument: Namespace is required", "InvalidArgument", nil)
+	}
+	if len(params.TargetClusterName) == 0 {
+		return temporal.NewNonRetryableApplicationError("InvalidArgument: TargetClusterName is required", "InvalidArgument", nil)
+	}
+	return nil
+}
+
+// maybeKickoffShardedTaskQueueUserDataReplication always starts a
+// coroutine that waits for the taskQueueUserDataReplicationDoneSignalType
+// signal and passes its string payload to onDone. The child workflow
+// ForceTaskQueueUserDataReplicationWorkflow is started only on the first
+// cycle (ContinuedAsNewCount == 0); later cycles just re-arm the signal
+// listener. The child uses PARENT_CLOSE_POLICY_ABANDON, and the returned
+// error is the result of waiting for the child execution to start.
+func maybeKickoffShardedTaskQueueUserDataReplication(ctx workflow.Context, params *ShardedForceReplicationParams, onDone func(failureReason string)) error {
+	workflow.Go(ctx, func(ctx workflow.Context) {
+		ch := workflow.GetSignalChannel(ctx, taskQueueUserDataReplicationDoneSignalType)
+		var errStr string
+		// Only the signal payload is used; Receive's return value is
+		// deliberately ignored.
+		_ = ch.Receive(ctx, &errStr)
+		onDone(errStr)
+	})
+
+	if params.ContinuedAsNewCount > 0 {
+		return nil
+	}
+
+	options := workflow.ChildWorkflowOptions{
+		WorkflowID:        fmt.Sprintf("%s-task-queue-user-data-replicator", workflow.GetInfo(ctx).WorkflowExecution.ID),
+		ParentClosePolicy: enumspb.PARENT_CLOSE_POLICY_ABANDON,
+	}
+	childCtx := workflow.WithChildOptions(ctx, options)
+	input := TaskQueueUserDataReplicationParamsWithNamespace{
+		TaskQueueUserDataReplicationParams: params.TaskQueueUserDataReplicationParams,
+		Namespace:                          params.Namespace,
+	}
+	child := workflow.ExecuteChildWorkflow(childCtx, ForceTaskQueueUserDataReplicationWorkflow, input)
+	var childExecution workflow.Execution
+	return child.GetChildWorkflowExecution().Get(ctx, &childExecution)
+}
+
+// shardedCountWorkflowsForReplication asks the frontend how many
+// workflows match the namespace's force-rep query. Used once at
+// workflow start to seed TotalForceReplicateWorkflowCount for the
+// status query's progress reporting.
+func shardedCountWorkflowsForReplication(ctx workflow.Context, params *ShardedForceReplicationParams) (int64, error) {
+	ao := workflow.ActivityOptions{
+		StartToCloseTimeout: 2 * time.Minute,
+		RetryPolicy:         forceReplicationActivityRetryPolicy,
+	}
+	var a *activities
+	var output countWorkflowResponse
+	if err := workflow.ExecuteActivity(
+		workflow.WithActivityOptions(ctx, ao),
+		a.CountWorkflow,
+		&workflowservice.CountWorkflowExecutionsRequest{
+			Namespace: params.Namespace,
+			Query:     params.Query,
+		}).Get(ctx, &output); err != nil {
+		return 0, err
+	}
+	return output.WorkflowCount, nil
+}
+
+// shardedWorkflowState holds the workflow's per-run state. Workflow
+// coroutines yield only at SDK calls, so plain maps + ints are safe
+// without mutexes — workflow.Await re-evaluates its predicate after
+// each yield, which is what makes the shard-in-flight bookkeeping
+// drive each dispatch coroutine's wait.
+type shardedWorkflowState struct {
+	params *ShardedForceReplicationParams
+
+	namespaceID string
+
+	// targetShardCount is the target cluster's history shard count,
+	// fetched once via DescribeTargetCluster at state construction.
+	// Drives the per-exec shard hash (so packing groups execs by their
+	// destination shard) and the default ConcurrentBatchCount.
+	targetShardCount int32
+
+	// buckets accumulate execs that have been listed but not yet
+	// dispatched. Nested by destination shard then businessID, so a
+	// hot BID's many runs share one BID-string-worth of bytes when
+	// the bucket is shipped over the wire.
+	buckets BatchPayload
+
+	// bucketCounts mirrors len of all runs across BIDs for each
+	// shard. Kept as a sidecar so the packer's per-shard ordering
+	// decisions are O(1) rather than O(#BIDs in shard); it's
+	// consulted many times per cycle.
+	bucketCounts map[int32]int
+
+	// shardInFlight is the per-shard exclusivity set: a shard's
+	// entry is set when it's part of any in-flight batch and
+	// cleared when that batch returns (either fully or via mid-flight
+	// signal-release). Concurrent batches are limited only by this
+	// set — there is no global slot cap.
+	// NOTE(review): ShardedForceReplicationParams.ConcurrentBatchCount
+	// is documented as an absolute ceiling on in-flight batches; confirm
+	// which of the two statements is current.
+	shardInFlight map[int32]bool
+
+	// heldByBatch tracks per-batch shard ownership. spawnBatch
+	// populates it with the batch's claimed shards; the signal
+	// handler removes entries as shards are released mid-flight;
+	// the dispatch coroutine's defer clears whatever's left after
+	// the activity returns. Required because a signal-released
+	// shard may have been re-claimed by a subsequent batch — the
+	// returning original batch must only clear its own remaining
+	// claims, not stomp on the new claimant.
+	heldByBatch map[int64]map[int32]bool
+
+	// activityCtx is a cancellable child of run()'s ctx; every
+	// dispatched batch activity is derived from it. cancelActivities
+	// is the matching cancel func — drainForCAN calls it once to
+	// cancel every in-flight batch at once. Cancelling activityCtx
+	// leaves the workflow's main ctx alive so the drain loop's
+	// Await keeps running.
+	activityCtx      workflow.Context
+	cancelActivities workflow.CancelFunc
+
+	// batchExecs tracks the input payload of each in-flight batch.
+	// Cleared on any nil-error return (drained execs are folded
+	// into drainPayload from the activity result; cleanly completed
+	// batches return an empty InFlight). Anything left at CAN time
+	// corresponds to a batch whose activity returned CanceledError
+	// with no result — i.e. the activity body never ran. Those
+	// execs are recovered into the next cycle's streaming buckets.
+	batchExecs map[int64]BatchPayload
+
+	// pendingDispatches counts spawned dispatch coroutines that
+	// have not yet returned. The main coroutine waits on this
+	// dropping to zero before issuing CAN or returning.
+	pendingDispatches int
+
+	// drainPayload accumulates ResumeShard entries from drained
+	// activities (via the activity result on nil-error return).
+	// Fed into the CAN-carry-over params at the end.
+	drainPayload []ResumeShard
+
+	// lastErr stops further dispatch once an activity errors out
+	// (e.g. ShardNoProgress). Without it the workflow would keep
+	// paging through ListWorkflows after a broken namespace surfaces,
+	// burning the rest of the population before returning the
+	// failure.
+	lastErr error
+
+	nextBatchID int64
+
+	// metricsHandler is tagged with the workflow's fixed scope +
+	// namespace once at state construction; recordVerified reuses it on
+	// every batch return.
+	metricsHandler sdkclient.MetricsHandler
+}
+
+// newShardedWorkflowState builds the per-run state for one cycle. It
+// runs the GetMetadata and DescribeTargetCluster activities, fails
+// non-retryably ("InvalidTargetShardCount") when the target reports a
+// non-positive shard count, and then mutates *params in place: every
+// tuning field that is <= 0 is replaced by its default, QPSQueue is
+// created and seeded when its Data is nil, and RecoveredBuckets is
+// moved into the state's buckets and set to nil.
+func newShardedWorkflowState(ctx workflow.Context, params *ShardedForceReplicationParams) (*shardedWorkflowState, error) {
+	ao := workflow.ActivityOptions{
+		StartToCloseTimeout: 24 * time.Hour,
+		HeartbeatTimeout:    time.Minute,
+		RetryPolicy:         forceReplicationActivityRetryPolicy,
+	}
+	metaCtx := workflow.WithActivityOptions(ctx, ao)
+	var a *activities
+	var md MetadataResponse
+	if err := workflow.ExecuteActivity(metaCtx, a.GetMetadata, MetadataRequest{Namespace: params.Namespace}).Get(ctx, &md); err != nil {
+		return nil, err
+	}
+	var targetMd DescribeTargetClusterResponse
+	if err := workflow.ExecuteActivity(metaCtx, a.DescribeTargetCluster, DescribeTargetClusterRequest{
+		TargetClusterName: params.TargetClusterName,
+	}).Get(ctx, &targetMd); err != nil {
+		return nil, err
+	}
+	if targetMd.ShardCount <= 0 {
+		return nil, temporal.NewNonRetryableApplicationError(
+			fmt.Sprintf("DescribeTargetCluster returned non-positive ShardCount (%d) for target %q", targetMd.ShardCount, params.TargetClusterName),
+			"InvalidTargetShardCount", nil)
+	}
+	if params.BatchSize <= 0 {
+		params.BatchSize = defaultBatchSize
+	}
+	if params.MaxExecsPerShard <= 0 {
+		params.MaxExecsPerShard = defaultMaxExecsPerShard
+	}
+	if params.ShardNoProgress <= 0 {
+		params.ShardNoProgress = defaultShardNoProgress
+	}
+	if params.DrainGrace <= 0 {
+		params.DrainGrace = defaultDrainGrace
+	}
+	if params.IdleShardCost <= 0 {
+		params.IdleShardCost = defaultIdleShardCost
+	}
+	if params.ListWorkflowsPageSize <= 0 {
+		params.ListWorkflowsPageSize = defaultShardedListPageSize
+	}
+	if params.PerBatchGenerateRPS <= 0 {
+		params.PerBatchGenerateRPS = defaultPerBatchGenerateRPS
+	}
+	if params.ConcurrentBatchCount <= 0 {
+		params.ConcurrentBatchCount = defaultConcurrentBatchCount(targetMd.ShardCount)
+	}
+	if params.EstimationMultiplier <= 0 {
+		params.EstimationMultiplier = 2
+	}
+	// QPSQueue is sized off ConcurrentBatchCount (one sample slot per
+	// expected in-flight batch + one for the starting count). Seeded
+	// with the current ReplicatedWorkflowCount so the very first
+	// post-CAN batch return has a baseline to compute the rate against.
+	if params.QPSQueue.Data == nil {
+		params.QPSQueue = NewQPSQueue(params.ConcurrentBatchCount, params.EstimationMultiplier)
+		params.QPSQueue.Enqueue(ctx, params.ReplicatedWorkflowCount)
+	}
+	s := &shardedWorkflowState{
+		params:           params,
+		namespaceID:      md.NamespaceID,
+		targetShardCount: targetMd.ShardCount,
+		buckets:          BatchPayload{},
+		bucketCounts:     map[int32]int{},
+		shardInFlight:    map[int32]bool{},
+		heldByBatch:      map[int64]map[int32]bool{},
+		batchExecs:       map[int64]BatchPayload{},
+		metricsHandler: workflow.GetMetricsHandler(ctx).WithTags(map[string]string{
+			metrics.OperationTagName: metrics.MigrationWorkflowScope,
+			NamespaceTagName:         params.Namespace,
+		}),
+	}
+	// Restore execs recovered from cancel-before-start batches in
+	// the prior cycle so the streaming packer picks them up
+	// alongside any new pages.
+	s.buckets.merge(params.RecoveredBuckets)
+	for sh, byBID := range params.RecoveredBuckets {
+		for _, runs := range byBID {
+			s.bucketCounts[sh] += len(runs)
+		}
+	}
+	params.RecoveredBuckets = nil
+	return s, nil
+}
+
+// run drives one cycle of the sharded workflow: it sets up the
+// cancellable activity context, starts the release-signal handler,
+// dispatches carried-over resume batches, then lists and packs pages
+// until continue-as-new is suggested, the last page is reached, or an
+// error latches. It returns the latched error if there is one, nil when
+// nothing is left to carry over, and otherwise a continue-as-new error
+// whose params have ContinuedAsNewCount incremented and the carry-over
+// ResumeShards / RecoveredBuckets filled in.
+func (s *shardedWorkflowState) run(ctx workflow.Context) error {
+	// Cancellable child ctx for all dispatched batch activities.
+	// drainForCAN cancels it once to drain in-flight batches without
+	// touching the workflow's main ctx (which the drain loop's
+	// Await still rides on).
+	actCtx, cancelAll := workflow.WithCancel(ctx)
+	defer cancelAll()
+	s.activityCtx = actCtx
+	s.cancelActivities = cancelAll
+
+	// Start the signal handler coroutine first so any signal
+	// arriving during resume dispatch or page-loop drains is
+	// processed promptly.
+	workflow.Go(ctx, s.handleReleaseSignals)
+
+	// Dispatch resume activities carried over from the prior cycle.
+	// Done before the page loop so their shards are claimed in
+	// shardInFlight before any new pages arrive — keeps the packer
+	// from racing to dispatch against them with fresh execs.
+	s.dispatchResumeBatches(ctx)
+
+	// Drive ListWorkflows until either we exhaust the namespace or
+	// the SDK signals that history is large enough to CAN. Errors
+	// here latch into lastErr and fall through to the unified exit
+	// funnel — same drain-and-decide path as activity-driven errors.
+	for !workflow.GetInfo(ctx).GetContinueAsNewSuggested() {
+		if s.lastErr != nil {
+			break
+		}
+		executions, nextPageToken, err := s.listWorkflowPage(ctx)
+		if err != nil {
+			s.setLastErr(err)
+			break
+		}
+		for _, ex := range executions {
+			sh := common.WorkflowIDToHistoryShard(s.namespaceID, ex.BusinessID, s.targetShardCount)
+			s.addToBucket(sh, ex.BusinessID, RunEntry{
+				RunID:       ex.RunID,
+				ArchetypeID: ex.ArchetypeID,
+			})
+		}
+		s.params.NextPageToken = nextPageToken
+
+		for s.tryPackStreaming(ctx, false) { //nolint:revive // intentional empty body
+		}
+
+		if len(nextPageToken) == 0 {
+			break
+		}
+	}
+
+	// Drain remaining buckets only when no error has latched — on
+	// error we deliberately stop scheduling new work and let the
+	// already-dispatched batches finish via awaitInFlightCompletion.
+	if s.lastErr == nil {
+		s.drainBuckets(ctx)
+	}
+
+	// Wait for in-flight activities. Cancels them when we already
+	// know we're failing or CAN-ing; on the clean success path
+	// (no error, no more pages) waits naturally so a healthy
+	// activity isn't cancelled into a CanceledError that masquerades
+	// as carry-over state.
+	s.awaitInFlightCompletion(ctx)
+
+	// Single exit decision. Recovery state has the same shape on
+	// either path — drainPayload + undispatched ResumeShards +
+	// batchExecs + leftover buckets — so the only difference between
+	// "fail with state" and "CAN with state" is the return value.
+	if s.lastErr != nil {
+		return s.lastErr
+	}
+	if !s.hasCarryover() {
+		return nil
+	}
+
+	next := *s.params
+	next.ContinuedAsNewCount++
+	next.ResumeShards = s.collectResumeShardsForCarryover()
+	next.RecoveredBuckets = s.collectRecoveredBucketsForCarryover()
+	return workflow.NewContinueAsNewError(ctx, ShardedForceReplicationWorkflow, next)
+}
+
+var (
+	shardedListWorkflowsRetryPolicy = &temporal.RetryPolicy{
+		InitialInterval:    time.Second,
+		BackoffCoefficient: 2.0,
+		MaximumAttempts:    3,
+	}
+	shardedListWorkflowsActivityOptions = workflow.ActivityOptions{
+		StartToCloseTimeout: time.Hour,
+		RetryPolicy:         shardedListWorkflowsRetryPolicy,
+	}
+
+	// Per-exec backoff still owns the per-exec retry; MaximumAttempts
+	// lets a transient activity failure recover via heartbeat-resume
+	// without losing inject progress. WaitForCancellation lets a
+	// cancelled activity run drain logic and return its drain result.
+	shardedReplicateBatchRetryPolicy = &temporal.RetryPolicy{
+		MaximumAttempts: 3,
+	}
+	shardedReplicateBatchActivityOptions = workflow.ActivityOptions{
+		StartToCloseTimeout: 24 * time.Hour,
+		HeartbeatTimeout:    time.Minute,
+		RetryPolicy:         shardedReplicateBatchRetryPolicy,
+		WaitForCancellation: true,
+	}
+)
+
+// listWorkflowPage runs the ListWorkflows activity for the page at
+// s.params.NextPageToken, using the configured namespace, query and
+// page size. It returns that page's executions and the token for the
+// following page; it does not advance s.params.NextPageToken itself.
+func (s *shardedWorkflowState) listWorkflowPage(ctx workflow.Context) ([]*ExecutionInfo, []byte, error) {
+	listCtx := workflow.WithActivityOptions(ctx, shardedListWorkflowsActivityOptions)
+	listReq := &workflowservice.ListWorkflowExecutionsRequest{
+		Namespace:     s.params.Namespace,
+		Query:         s.params.Query,
+		PageSize:      int32(s.params.ListWorkflowsPageSize),
+		NextPageToken: s.params.NextPageToken,
+	}
+	var a *activities
+	var listResp listWorkflowsResponse
+	if err := workflow.ExecuteActivity(listCtx, a.ListWorkflows, listReq).Get(ctx, &listResp); err != nil {
+		return nil, nil, err
+	}
+	return listResp.Executions, listResp.NextPageToken, nil
+}
+
+// replicateBatch executes the ReplicateBatch activity for req and blocks
+// until it returns. The activity is scheduled on a context derived from
+// activityParentCtx (so cancelling that context cancels the activity),
+// while the wait for the result uses ctx. On error the zero
+// replicateBatchResult is returned alongside it.
+func (s *shardedWorkflowState) replicateBatch(ctx, activityParentCtx workflow.Context, req *shardedBatchReq) (replicateBatchResult, error) {
+	actx := workflow.WithActivityOptions(activityParentCtx, shardedReplicateBatchActivityOptions)
+	var a *activities
+	var result replicateBatchResult
+	if err := workflow.ExecuteActivity(actx, a.ReplicateBatch, req).Get(ctx, &result); err != nil {
+		return replicateBatchResult{}, err
+	}
+	return result, nil
+}
+
+// awaitInFlightCompletion drains in-flight batches before the workflow
+// exits. The strategy depends on what we already know:
+//
+//   - Error latched or more pages remain (we're going to fail or CAN
+//     either way): cancel immediately via drainForCAN, bounded by
+//     DrainGrace + IdleShardCost. No point waiting for activities
+//     that are going to be discarded.
+//   - Clean success path (no error, no more pages): wait for natural
+//     completion so a healthy activity's clean result isn't masked
+//     as a CanceledError. If an activity hits its ShardNoProgress
+//     backstop mid-wait, latch the error then cancel the rest fast
+//     rather than waiting on every shard's backstop too.
+func (s *shardedWorkflowState) awaitInFlightCompletion(ctx workflow.Context) {
+	if s.pendingDispatches == 0 {
+		return
+	}
+	if s.lastErr != nil || len(s.params.NextPageToken) > 0 {
+		s.drainForCAN(ctx)
+		return
+	}
+	// The Await error is dropped on purpose; whatever ended the wait,
+	// the check below drains anything still pending.
+	// NOTE(review): confirm that falling through to drainForCAN is the
+	// intended handling if Await returns because ctx was cancelled.
+	_ = workflow.Await(ctx, func() bool {
+		return s.pendingDispatches == 0 || s.lastErr != nil
+	})
+	if s.pendingDispatches > 0 {
+		s.drainForCAN(ctx)
+	}
+}
+
+// hasCarryover reports whether the workflow has any state worth
+// preserving across an exit — either a remaining page token, drained
+// execs from in-flight batches, cancel-before-start batches that
+// never injected, undispatched resume entries, or listed-but-unpacked
+// execs. Drives both the "CAN vs return nil" decision and the
+// recovery bundle exposed in the status query.
+func (s *shardedWorkflowState) hasCarryover() bool { + if len(s.params.NextPageToken) > 0 { + return true + } + if len(s.drainPayload) > 0 { + return true + } + if len(s.batchExecs) > 0 { + return true + } + if len(s.params.ResumeShards) > 0 { + return true + } + return !s.bucketsEmpty() +} + +// collectResumeShardsForCarryover concatenates this cycle's drained +// execs with any prior-cycle ResumeShards that didn't get dispatched +// (left in params.ResumeShards by dispatchResumeBatches when it bailed +// out on lastErr). Both groups are already shard-keyed; the next +// cycle's dispatchResumeBatches sorts and re-packs them. +func (s *shardedWorkflowState) collectResumeShardsForCarryover() []ResumeShard { + if len(s.drainPayload) == 0 && len(s.params.ResumeShards) == 0 { + return nil + } + out := make([]ResumeShard, 0, len(s.drainPayload)+len(s.params.ResumeShards)) + out = append(out, s.drainPayload...) + out = append(out, s.params.ResumeShards...) + return out +} + +// collectRecoveredBucketsForCarryover merges the two sources of +// "execs that never made it through a verify activity this cycle": +// batches that returned CanceledError without running a body, and +// listed-but-unpacked execs still sitting in s.buckets when the +// workflow exited (either lastErr stopped the streaming packer or +// drainBuckets bailed out on lastErr partway through). +func (s *shardedWorkflowState) collectRecoveredBucketsForCarryover() BatchPayload { + out := collectRecoveredBuckets(s.batchExecs) + if !s.bucketsEmpty() { + if out == nil { + out = BatchPayload{} + } + out.merge(s.buckets) + } + return out +} + +// recordVerified accumulates one batch's verified-exec delta into the +// workflow's running count, emits the per-batch counter delta, and +// updates the sliding-window RPS gauge. No-op when verified == 0 so a +// batch that ran entirely as drain-no-progress doesn't poison the +// QPSQueue with a zero-delta sample. 
+func (s *shardedWorkflowState) recordVerified(ctx workflow.Context, verified int64) { + if verified <= 0 { + return + } + s.params.ReplicatedWorkflowCount += verified + + s.metricsHandler.Counter(metrics.ReplicatedWorkflowCount.Name()).Inc(verified) + + s.params.QPSQueue.Enqueue(ctx, s.params.ReplicatedWorkflowCount) + s.params.ReplicatedWorkflowCountPerSecond = s.params.QPSQueue.CalculateQPS() + s.metricsHandler.Gauge(ForceReplicationRpsTagName).Update(s.params.ReplicatedWorkflowCountPerSecond) +} + +// defaultConcurrentBatchCount derives the in-flight-batch ceiling +// from the target cluster's shard count: a quarter of the shards, +// capped at defaultConcurrentBatchCap. The 1/4 fraction leaves worker +// slots free for unrelated activities; the absolute cap bounds the +// cluster blast radius regardless of cluster size. Returns at least 1. +func defaultConcurrentBatchCount(shards int32) int { + return max(min(int(shards)/4, defaultConcurrentBatchCap), 1) +} + +// collectRecoveredBuckets re-buckets any execs from batches whose +// dispatching activity returned CanceledError without returning a +// result — i.e. the activity body never ran, so its execs were +// never injected. They go back into the next cycle's streaming +// buckets to be dispatched as fresh inject+verify batches. The +// shard is the top-level map key on each batch's payload, so no +// re-hashing here — collectRecoveredBuckets just merges. +func collectRecoveredBuckets(batchExecs map[int64]BatchPayload) BatchPayload { + if len(batchExecs) == 0 { + return nil + } + out := BatchPayload{} + for _, bp := range batchExecs { + out.merge(bp) + } + return out +} + +// addToBucket appends one run to the (shard, BID) bucket and bumps +// the sidecar count. 
+func (s *shardedWorkflowState) addToBucket(shard int32, businessID string, run RunEntry) { + if s.buckets[shard] == nil { + s.buckets[shard] = map[string][]RunEntry{} + } + s.buckets[shard][businessID] = append(s.buckets[shard][businessID], run) + s.bucketCounts[shard]++ +} + +// takeFromBucket consumes up to n runs from the given shard and +// returns them grouped by BID. Walks BIDs in alphabetical order so +// the resulting payload is deterministic across replays; takes whole +// per-BID runs only as needed to reach n. Empties the shard from +// s.buckets / s.bucketCounts when nothing remains. +func (s *shardedWorkflowState) takeFromBucket(shard int32, n int) map[string][]RunEntry { + if n <= 0 { + return nil + } + byBID := s.buckets[shard] + if len(byBID) == 0 { + return nil + } + bids := make([]string, 0, len(byBID)) + for bid := range byBID { + bids = append(bids, bid) + } + slices.Sort(bids) + + out := map[string][]RunEntry{} + taken := 0 + for _, bid := range bids { + if taken >= n { + break + } + runs := byBID[bid] + take := min(len(runs), n-taken) + // append([]RunEntry(nil), ...) gives the output its own + // backing array — keeps the workflow's leftover slice + // (byBID[bid][take:]) and the activity's input independent + // in case either side appends later. + out[bid] = append([]RunEntry(nil), runs[:take]...) + if take == len(runs) { + delete(byBID, bid) + } else { + byBID[bid] = runs[take:] + } + taken += take + } + s.bucketCounts[shard] -= taken + if s.bucketCounts[shard] <= 0 { + delete(s.bucketCounts, shard) + delete(s.buckets, shard) + } + return out +} + +// handleReleaseSignals runs as a long-lived workflow coroutine, +// consuming ReleaseShards signals from in-flight activities. 
Each +// signal lists shards the activity considers complete; the handler +// clears them from heldByBatch[BatchID] (so the dispatch coroutine's +// defer won't double-release) and shardInFlight (so the packer can +// dispatch new work against them while the activity stays running on +// its still-pending shards). +// +// DO NOT add workflow yields (ExecuteActivity, Sleep, Await, etc.) +// between Receive and the next Receive. drainForCAN relies on +// ch.Len() == 0 implying "every delivered signal has been processed"; +// a yield mid-handler would invalidate that, leaving shardInFlight +// stale after a CAN. +func (s *shardedWorkflowState) handleReleaseSignals(ctx workflow.Context) { + ch := workflow.GetSignalChannel(ctx, releaseShardsSignalName) + for ctx.Err() == nil { + var payload releaseShardsPayload + if !ch.Receive(ctx, &payload) { + return + } + held, ok := s.heldByBatch[payload.BatchID] + if !ok { + continue + } + for _, sh := range payload.Shards { + if held[sh] { + delete(held, sh) + delete(s.shardInFlight, sh) + } + } + } +} + +// extractVerifiedCountFromError pulls the partial VerifiedCount that +// wrapBatchVerifyError encoded into a BatchVerifyPartial-typed +// ApplicationError's Details on the activity side. Returns 0 when +// the error didn't come through the verify-phase wrapper (e.g. +// inject-phase failures, ctx errors, non-ApplicationError types), so +// callers can unconditionally fold the result into recordVerified. +func extractVerifiedCountFromError(err error) int64 { + if err == nil { + return 0 + } + appErr, ok := errors.AsType[*temporal.ApplicationError](err) + if !ok || appErr.Type() != batchVerifyPartialErrorType { + return 0 + } + var count int64 + if appErr.Details(&count) != nil { + return 0 + } + return count +} + +// setLastErr latches the first error encountered. 
Subsequent errors +// are dropped so the root cause is preserved for the workflow's +// returned failure — without the latch, a stuck-shard backstop firing +// on every batch as the workflow tears down would overwrite the +// genuinely interesting first failure. +func (s *shardedWorkflowState) setLastErr(err error) { + if s.lastErr == nil { + s.lastErr = err + } +} + +// dispatchSlotAvailable returns true when the workflow is below the +// in-flight batch ceiling and is free to spawn another batch. Callers +// that can defer dispatch (the streaming packer) consult this and +// bail out; callers that must dispatch (resume payloads) pair it +// with waitForDispatchSlot. ConcurrentBatchCount is normalised to +// >= 1 at state construction, so no zero-disable path is needed. +func (s *shardedWorkflowState) dispatchSlotAvailable() bool { + return s.pendingDispatches < s.params.ConcurrentBatchCount +} + +// waitForDispatchSlot blocks the calling workflow coroutine until a +// dispatch slot frees up or lastErr trips. +func (s *shardedWorkflowState) waitForDispatchSlot(ctx workflow.Context) { + _ = workflow.Await(ctx, func() bool { + return s.lastErr != nil || s.pendingDispatches < s.params.ConcurrentBatchCount + }) +} + +// resumeBatch is one packed dispatch plan: the BatchPayload that will +// become a batch's input, plus the matching per-shard no-progress +// durations. Built up front by packResumeBatchPlan so the dispatch +// loop can unpack any remainder back into ResumeShards if lastErr +// trips mid-dispatch. +type resumeBatch struct { + payload BatchPayload + noProgress map[int32]time.Duration +} + +// dispatchResumeBatches turns the prior cycle's drain payload into a +// fresh round of resume activities, packed across shards up to +// BatchSize per batch. 
Each shard appears at most once across the +// payload (shardInFlight enforces that only one batch holds a shard +// at a time, and a shard only lands in a drain return while its +// owning batch still has unverified execs on it), so per-shard +// contributions are taken whole and no MaxExecsPerShard cap applies — +// resume carries no inject load so the per-shard blast-radius the +// streaming packer guards against doesn't exist here. +// +// Plans every batch up front, then dispatches one at a time. If +// lastErr latches mid-dispatch, the remaining planned batches are +// unpacked back into s.params.ResumeShards so the recovery bundle +// (and the next CAN cycle) sees them — without the unpack step, a +// failing first resume batch would silently strand all subsequent +// resume entries. +func (s *shardedWorkflowState) dispatchResumeBatches(ctx workflow.Context) { + if len(s.params.ResumeShards) == 0 { + return + } + entries := make([]ResumeShard, 0, len(s.params.ResumeShards)) + for _, rs := range s.params.ResumeShards { + if runCount(rs.Execs) == 0 { + continue + } + entries = append(entries, rs) + } + // We've taken ownership of these entries — anything not + // dispatched gets restored below. + s.params.ResumeShards = nil + slices.SortFunc(entries, func(a, b ResumeShard) int { + return int(a.Shard - b.Shard) + }) + + batches := s.packResumeBatchPlan(entries) + for i, batch := range batches { + if s.lastErr != nil { + s.params.ResumeShards = unpackResumeBatches(batches[i:]) + return + } + // Block until a dispatch slot is free so resume payloads + // can't overshoot ConcurrentBatchCount on cycles that + // carried many shards across CAN. 
+ s.waitForDispatchSlot(ctx) + if s.lastErr != nil { + s.params.ResumeShards = unpackResumeBatches(batches[i:]) + return + } + for sh := range batch.payload { + s.shardInFlight[sh] = true + } + s.spawnBatch(ctx, batch.payload, true, batch.noProgress) + } +} + +// packResumeBatchPlan groups ResumeShard entries into batches sized +// at or below BatchSize. Entries are taken whole — shardInFlight only +// admits one batch per shard at a time, so a single shard's payload +// can't be split. +func (s *shardedWorkflowState) packResumeBatchPlan(entries []ResumeShard) []resumeBatch { + var batches []resumeBatch + current := resumeBatch{payload: BatchPayload{}, noProgress: map[int32]time.Duration{}} + packed := 0 + for _, rs := range entries { + rsCount := runCount(rs.Execs) + if packed+rsCount > s.params.BatchSize && packed > 0 { + batches = append(batches, current) + current = resumeBatch{payload: BatchPayload{}, noProgress: map[int32]time.Duration{}} + packed = 0 + } + current.payload[rs.Shard] = rs.Execs + current.noProgress[rs.Shard] = rs.NoProgressDuration + packed += rsCount + } + if packed > 0 { + batches = append(batches, current) + } + return batches +} + +// unpackResumeBatches reverses packResumeBatchPlan, turning planned +// batches back into a flat ResumeShard slice. Used when the dispatch +// loop aborts on lastErr so the undispatched remainder can be carried +// into the recovery bundle / next CAN cycle. +func unpackResumeBatches(batches []resumeBatch) []ResumeShard { + var out []ResumeShard + for _, b := range batches { + for sh, execs := range b.payload { + out = append(out, ResumeShard{ + Shard: sh, + Execs: execs, + NoProgressDuration: b.noProgress[sh], + }) + } + } + return out +} + +// runCount sums runs across BIDs in a single shard's payload entry. 
+func runCount(byBID map[string][]RunEntry) int { + n := 0 + for _, runs := range byBID { + n += len(runs) + } + return n +} + +// drainBuckets blocks until buckets are empty (success) or lastErr +// trips (failure). Each pass packs everything currently dispatchable, +// then awaits any change in pendingDispatches + shardInFlight so the +// next pass can attempt shards just freed by signal-release. +func (s *shardedWorkflowState) drainBuckets(ctx workflow.Context) { + for { + if s.lastErr != nil { + return + } + for s.tryPackStreaming(ctx, true) { //nolint:revive + } + if s.bucketsEmpty() || s.lastErr != nil { + return + } + currentPending := s.pendingDispatches + if currentPending == 0 { + s.failDrainBucketsStuck() + return + } + _ = workflow.Await(ctx, s.drainBucketsAwaitPredicate(currentPending)) + } +} + +// failDrainBucketsStuck sets lastErr when buckets are non-empty but no +// batches are in flight — the shard-claim bookkeeping is corrupted, and +// returning silently would proceed to CAN with execs that were never +// dispatched (silent data loss). Failing forces lastErr to propagate +// through run(). +func (s *shardedWorkflowState) failDrainBucketsStuck() { + remaining := 0 + for _, n := range s.bucketCounts { + remaining += n + } + s.setLastErr(temporal.NewNonRetryableApplicationError( + fmt.Sprintf("drainBuckets: %d execs in buckets but no batches in flight (shard-claim bookkeeping corrupted)", remaining), + "DrainBucketsStuck", nil)) +} + +// drainBucketsAwaitPredicate returns true when the drainBuckets loop +// should wake up: lastErr tripped, a dispatch slot just freed, or a +// new free shard is ready to pack. A "free shard" wake-up only counts +// when there's also a dispatch slot to use it, otherwise the outer +// loop would busy-spin on tryPackStreaming returning false against the +// in-flight cap. 
+func (s *shardedWorkflowState) drainBucketsAwaitPredicate(currentPending int) func() bool { + return func() bool { + if s.lastErr != nil { + return true + } + if s.pendingDispatches < currentPending { + return true + } + if !s.dispatchSlotAvailable() { + return false + } + for sh, n := range s.bucketCounts { + if n > 0 && !s.shardInFlight[sh] { + return true + } + } + return false + } +} + +// drainForCAN cancels every in-flight batch and waits for them to +// return AND for the ReleaseShards signal channel to be drained. +// Activities honour cancellation by entering drain mode and returning +// a result whose InFlight carries their still-unverified execs; +// spawnBatch appends those entries to s.drainPayload. The signal +// channel drain is so a final ReleaseShards fired by an activity just +// before it returns doesn't get stranded mid-flight, which would +// leave shardInFlight set for shards the activity already considers +// complete. +// +// Channel.Len() is safe here because handleReleaseSignals has no +// yield points between Receive and the next blocking Receive, so +// Len() == 0 observed across an Await re-evaluation means every +// delivered signal has been processed. +// +// No explicit time bound: a well-behaved activity returns within +// req.DrainGrace (15s default) plus a small idle-cost slack; a +// misbehaved one is bounded by the activity's HeartbeatTimeout (1m). +// In practice this Await unblocks well under a minute. +func (s *shardedWorkflowState) drainForCAN(ctx workflow.Context) { + if s.pendingDispatches == 0 { + return + } + s.cancelActivities() + releaseCh := workflow.GetSignalChannel(ctx, releaseShardsSignalName) + // Wait unconditionally for pendingDispatches to drain — lastErr may + // already be set on entry, but drainPayload, batchExecs, and status + // recovery fields only finalise once every in-flight goroutine has + // returned. 
// spawnBatch dispatches one batch on a new workflow.Go coroutine.
// Callers must have already marked every shard appearing as a
// top-level key in payload as shardInFlight (the "claim") so the
// packer can see them as busy while picking subsequent batches. The
// activity is run on s.activityCtx so drainForCAN can cancel every
// in-flight batch with a single call.
//
// Parameters:
//   - ctx: parent context for the spawned coroutine.
//   - payload: execs to replicate, keyed by shard; becomes the
//     request's Executions.
//   - resume: copied into the request's Resume field.
//   - noProgressByShard: copied into the request's NoProgressByShard
//     field; the streaming packer passes nil.
//
// A payload with zero runs is a no-op.
// NOTE(review): that early return does not undo the caller's
// shardInFlight claim; the visible callers only pass non-empty
// payloads — confirm no other caller can reach it.
func (s *shardedWorkflowState) spawnBatch(
	ctx workflow.Context,
	payload BatchPayload,
	resume bool,
	noProgressByShard map[int32]time.Duration,
) {
	if payload.totalRuns() == 0 {
		return
	}
	// Batch IDs are a per-cycle counter starting at 1.
	s.nextBatchID++
	batchID := s.nextBatchID

	req := &shardedBatchReq{
		BatchID:             batchID,
		Namespace:           s.params.Namespace,
		NamespaceID:         s.namespaceID,
		Executions:          payload,
		TargetClusterName:   s.params.TargetClusterName,
		Resume:              resume,
		DisableVerification: s.params.DisableVerification,
		NoProgressByShard:   noProgressByShard,
		PerBatchGenerateRPS: s.params.PerBatchGenerateRPS,
		ShardNoProgress:     s.params.ShardNoProgress,
		DrainGrace:          s.params.DrainGrace,
		IdleShardCost:       s.params.IdleShardCost,
	}

	// Record which shards this batch holds (handleReleaseSignals
	// shrinks this set) and keep the payload so a batch that never ran
	// can be re-bucketed at exit.
	held := make(map[int32]bool, len(payload))
	for sh := range payload {
		held[sh] = true
	}
	s.heldByBatch[batchID] = held
	s.batchExecs[batchID] = payload

	// Counted before the coroutine starts so the in-flight total is
	// already accurate when this function returns.
	s.pendingDispatches++
	workflow.Go(ctx, func(coroCtx workflow.Context) {
		defer func() {
			s.pendingDispatches--
			// Clear any shards we still hold — signal-released
			// shards have already been cleared from shardInFlight
			// by handleReleaseSignals and may by now belong to a
			// subsequent batch's claim.
			for sh := range s.heldByBatch[batchID] {
				delete(s.shardInFlight, sh)
			}
			delete(s.heldByBatch, batchID)
		}()
		result, err := s.replicateBatch(coroCtx, s.activityCtx, req)
		if err != nil {
			if temporal.IsCanceledError(err) {
				// Cancel-before-start: the activity body never ran,
				// so no result is available. Leaving batchExecs[batchID]
				// intact lets the CAN-end recovery path re-bucket the
				// execs as fresh inject+verify work next cycle.
				return
			}
			// Activity errored after partial verify — the SDK discards
			// the result on failure, but wrapBatchVerifyError on the
			// activity side carries the partial doneCount through as
			// ApplicationError details. Fold it into the running count
			// so ReplicatedWorkflowCount reflects work actually done.
			// batchExecs[batchID] is left in place on this path too.
			s.recordVerified(coroCtx, extractVerifiedCountFromError(err))
			s.setLastErr(err)
			return
		}
		// Activity body ran and returned cleanly — either a
		// clean completion (empty InFlight) or a drained
		// CAN-cancel (InFlight carries the still-unverified
		// execs). CompletedShards is informational; the
		// defer above clears heldByBatch + shardInFlight
		// either way.
		if len(result.InFlight) > 0 {
			s.drainPayload = append(s.drainPayload, result.InFlight...)
		}
		s.recordVerified(coroCtx, result.VerifiedCount)
		// A result was returned, so the payload no longer needs
		// re-bucketing.
		delete(s.batchExecs, batchID)
	})
}
+// +// See shardIDsByPackPriority for the relax-mode ordering rationale. +func (s *shardedWorkflowState) tryPackStreaming(ctx workflow.Context, relax bool) bool { + if s.lastErr != nil || s.params.BatchSize <= 0 || s.params.MaxExecsPerShard <= 0 { + return false + } + if !s.dispatchSlotAvailable() { + return false + } + shardIDs := s.shardIDsByPackPriority(relax) + if len(shardIDs) == 0 { + return false + } + + payload := BatchPayload{} + packed := 0 + for _, sh := range shardIDs { + room := s.params.BatchSize - packed + if room <= 0 { + break + } + take := min(s.params.MaxExecsPerShard, s.bucketCounts[sh], room) + if take == 0 { + continue + } + payload[sh] = s.takeFromBucket(sh, take) + packed += take + s.shardInFlight[sh] = true + } + if packed == 0 { + return false + } + s.spawnBatch(ctx, payload, false, nil) + return true +} + +// bucketsEmpty reports whether every shard's bucket is empty. Reads +// from the sidecar count map so it's O(#shards), not O(#runs). +func (s *shardedWorkflowState) bucketsEmpty() bool { + for _, n := range s.bucketCounts { + if n > 0 { + return false + } + } + return true +} + +// shardIDsByPackPriority returns free, non-empty shard IDs in the +// order the packer should consider them. Deterministic across +// replays: ordering is derived from workflow state (bucketCounts) +// with shard ID as a stable tiebreaker. +// +// relax=false (streaming): fullest first. Packer naturally produces +// large, predictable batches when work is plentiful and small ones +// when it isn't — either way it ships rather than waiting. +// +// relax=true (drain): hot shards (count > MaxExecsPerShard) first, +// fullest within hot. These need >1 round trip to drain, so total +// drain wall-clock is bounded by the heaviest shard; starting their +// pipelines first is the dominant lever. 
After all hot shards are +// claimed, remaining batch capacity fills from smallest cold buckets +// (ascending count) so light shards clear out quickly — nothing is +// arriving in drain, so waiting for cold buckets to grow is wasted +// wall-clock. +func (s *shardedWorkflowState) shardIDsByPackPriority(relax bool) []int32 { + out := make([]int32, 0, len(s.bucketCounts)) + for sh, n := range s.bucketCounts { + if n == 0 || s.shardInFlight[sh] { + continue + } + out = append(out, sh) + } + if relax { + maxPerShard := s.params.MaxExecsPerShard + slices.SortFunc(out, func(a, b int32) int { + aHot, bHot := s.bucketCounts[a] > maxPerShard, s.bucketCounts[b] > maxPerShard + switch { + case aHot && !bHot: + return -1 + case !aHot && bHot: + return 1 + case aHot && bHot: + if d := s.bucketCounts[b] - s.bucketCounts[a]; d != 0 { + return d + } + default: + if d := s.bucketCounts[a] - s.bucketCounts[b]; d != 0 { + return d + } + } + return int(a - b) + }) + return out + } + slices.SortFunc(out, func(a, b int32) int { + if d := s.bucketCounts[b] - s.bucketCounts[a]; d != 0 { + return d + } + return int(a - b) + }) + return out +} diff --git a/service/worker/migration/sharded_workflow_test.go b/service/worker/migration/sharded_workflow_test.go new file mode 100644 index 00000000000..bb7e852b92c --- /dev/null +++ b/service/worker/migration/sharded_workflow_test.go @@ -0,0 +1,931 @@ +package migration + +import ( + "context" + "errors" + "fmt" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/sdk/activity" + "go.temporal.io/sdk/converter" + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/testsuite" + "go.temporal.io/sdk/workflow" + "go.temporal.io/server/common" +) + +// testNamespaceID is what metadataResponseFor returns. 
Tests pass it +// through to bidsForShards so the BIDs they hand to makeExecs hash +// to the same shards the workflow will compute during the page loop. +const testNamespaceID = "test-ns-id" + +// ---- Test setup helpers ---- + +// bidsForShards returns, for each shard the hash actually populates +// under common.WorkflowIDToHistoryShard(namespaceID, bid, totalShards), +// a slice of `perShard` BusinessIDs. Brute-force search over candidate +// strings — sufficient for the small shard counts the tests use. +func bidsForShards(namespaceID string, totalShards int32, perShard int) map[int32][]string { + out := make(map[int32][]string, totalShards) + for i := 0; ; i++ { + bid := fmt.Sprintf("wf-%d", i) + sh := common.WorkflowIDToHistoryShard(namespaceID, bid, totalShards) + if len(out[sh]) < perShard { + out[sh] = append(out[sh], bid) + } + if int32(len(out)) == totalShards { + done := true + for sh := range out { + if len(out[sh]) < perShard { + done = false + break + } + } + if done { + return out + } + } + } +} + +// makeExecs builds a slice of ExecutionInfos engineered to hash across +// `shards` distinct shards (`perShard` execs per shard) under the test +// namespace ID + shard count. The workflow's page loop computes the +// destination shard itself via common.WorkflowIDToHistoryShard. +func makeExecs(shards int32, perShard int) []*ExecutionInfo { + bids := bidsForShards(testNamespaceID, shards, perShard) + var execs []*ExecutionInfo + idx := 0 + for sh := range bids { + for _, bid := range bids[sh] { + execs = append(execs, &ExecutionInfo{ + BusinessID: bid, + RunID: "run-" + strconv.Itoa(idx), + }) + idx++ + } + } + return execs +} + +// pageThrough returns a function suitable for OnActivity("ListWorkflows") +// that paginates `all` into pages of `pageSize` execs each. The workflow +// computes each exec's destination shard itself, so callers don't need +// to set anything shard-related on the ExecutionInfos. 
+func pageThrough(all []*ExecutionInfo, pageSize int) func(context.Context, *workflowservice.ListWorkflowExecutionsRequest) (*listWorkflowsResponse, error) { + return func(_ context.Context, req *workflowservice.ListWorkflowExecutionsRequest) (*listWorkflowsResponse, error) { + start := 0 + if len(req.NextPageToken) > 0 { + start, _ = strconv.Atoi(string(req.NextPageToken)) + } + end := min(start+pageSize, len(all)) + var nextToken []byte + if end < len(all) { + nextToken = []byte(strconv.Itoa(end)) + } + return &listWorkflowsResponse{ + Executions: all[start:end], + NextPageToken: nextToken, + }, nil + } +} + +// metadataResponseFor returns a function suitable for +// OnActivity("GetMetadata") that yields a fixed shard count + ns ID. +func metadataResponseFor(shardCount int32) func(context.Context, MetadataRequest) (*MetadataResponse, error) { + return func(_ context.Context, _ MetadataRequest) (*MetadataResponse, error) { + return &MetadataResponse{ + ShardCount: shardCount, + NamespaceID: testNamespaceID, + }, nil + } +} + +// registerShardedScaffolding registers the GetMetadata + CountWorkflow +// stubs every sharded test needs before the page loop runs, plus the +// task-queue-user-data child workflow + its activity so the parent's +// terminal Await on Done resolves. +func registerShardedScaffolding(env *testsuite.TestWorkflowEnvironment, shardCount int32) { + registerShardedScaffoldingWithSeed(env, shardCount, func(_ context.Context, _ TaskQueueUserDataReplicationParamsWithNamespace) error { + return nil + }) +} + +// registerShardedScaffoldingWithSeed is like registerShardedScaffolding +// but lets the caller supply the SeedReplicationQueueWithUserDataEntries +// activity body — needed for tests that exercise the seed-failure path. 
+func registerShardedScaffoldingWithSeed( + env *testsuite.TestWorkflowEnvironment, + shardCount int32, + seed func(context.Context, TaskQueueUserDataReplicationParamsWithNamespace) error, +) { + env.RegisterActivityWithOptions(metadataResponseFor(shardCount), activity.RegisterOptions{Name: "GetMetadata"}) + env.RegisterActivityWithOptions(func(_ context.Context, _ DescribeTargetClusterRequest) (*DescribeTargetClusterResponse, error) { + return &DescribeTargetClusterResponse{ShardCount: shardCount}, nil + }, activity.RegisterOptions{Name: "DescribeTargetCluster"}) + env.RegisterActivityWithOptions(func(_ context.Context, _ *workflowservice.CountWorkflowExecutionsRequest) (*countWorkflowResponse, error) { + return &countWorkflowResponse{WorkflowCount: 0}, nil + }, activity.RegisterOptions{Name: "CountWorkflow"}) + env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) + env.RegisterActivityWithOptions(seed, activity.RegisterOptions{Name: "SeedReplicationQueueWithUserDataEntries"}) +} + +// ---- Tests ---- + +// TestSharded_HappyPath_SingleCycle: a small workload exhausts in one +// cycle, every batch returns clean completion, no CAN, no resume. 
+func TestSharded_HappyPath_SingleCycle(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + + execs := makeExecs(4, 5) // 20 execs across 4 shards + registerShardedScaffolding(env, 4) + env.RegisterActivityWithOptions(pageThrough(execs, 1000), activity.RegisterOptions{Name: "ListWorkflows"}) + + var ( + mu sync.Mutex + batchesSeen int + execsSeen int + shardsSeen = map[int32]struct{}{} + ) + env.RegisterActivityWithOptions(func(_ context.Context, req *shardedBatchReq) (replicateBatchResult, error) { + mu.Lock() + batchesSeen++ + execsSeen += req.Executions.totalRuns() + for _, sh := range req.Executions.sortedShards() { + shardsSeen[sh] = struct{}{} + } + mu.Unlock() + return replicateBatchResult{}, nil + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + }) + + require.True(t, env.IsWorkflowCompleted(), "workflow should complete") + require.NoError(t, env.GetWorkflowError(), "workflow should succeed") + // 20 execs across 4 shards, BatchSize=100, MaxExecsPerShard=50, single + // page — every exec fits in one batch. + require.Equal(t, 1, batchesSeen, "all execs should pack into a single batch") + require.Equal(t, len(execs), execsSeen, "every exec should reach the activity") + require.Len(t, shardsSeen, 4, "every shard should be represented") +} + +// TestSharded_ResumeShards_Packed: a non-empty ResumeShards in params +// gets packed into multi-shard batches up to BatchSize. Asserts that +// no resume batch exceeds BatchSize and the per-shard contributions +// match the input. 
+func TestSharded_ResumeShards_Packed(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + registerShardedScaffolding(env, 8) + env.RegisterActivityWithOptions(pageThrough(nil, 1000), activity.RegisterOptions{Name: "ListWorkflows"}) + + // 8 shards, 15 unverified execs each = 120 total. With BatchSize=100, + // we expect 2 batches: e.g., [0..5] (90 execs) + [5+, 6, 7] (30) or + // similar — depends on greedy packing. + resumeShards := make([]ResumeShard, 8) + for s := range 8 { + resumeShards[s] = ResumeShard{ + Shard: int32(s), + Execs: makeExecsForShard(int32(s), 15), + } + } + + var ( + mu sync.Mutex + batches [][]int32 + ) + env.RegisterActivityWithOptions(func(_ context.Context, req *shardedBatchReq) (replicateBatchResult, error) { + require.True(t, req.Resume, "all resume-dispatched batches must have Resume=true") + require.LessOrEqual(t, req.Executions.totalRuns(), 100, "batch must not exceed BatchSize") + mu.Lock() + batches = append(batches, req.Executions.sortedShards()) + mu.Unlock() + return replicateBatchResult{}, nil + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + ResumeShards: resumeShards, + }) + + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + + // Confirm every shard was covered exactly once across all batches. 
+ covered := map[int32]int{} + for _, b := range batches { + for _, sh := range b { + covered[sh]++ + } + } + for s := range int32(8) { + require.Equal(t, 1, covered[s], "shard %d should be covered exactly once", s) + } + require.GreaterOrEqual(t, len(batches), 2, "should pack into at least 2 batches given 120 execs / BatchSize=100") +} + +// TestSharded_ReleaseShards_FreesShardForReuse: an activity sends a +// ReleaseShards signal mid-flight. The workflow must clear the shard +// from shardInFlight so the packer can dispatch a fresh batch +// targeting that shard *while the original activity is still running* +// — the slot in ConcurrentBatchCount is still claimed by batch 1, so +// batch 2 can only dispatch if signal-release worked. +// +// ConcurrentBatchCount=2 is explicit: defaultConcurrentBatchCount(2)=1 +// would gate batch 2 on the dispatch slot regardless of shard state, so +// the test couldn't distinguish "release-from-flight" from +// "batch 1 returned and freed the slot". +func TestSharded_ReleaseShards_FreesShardForReuse(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + registerShardedScaffolding(env, 2) + + // Two pages of execs on the same shards. The second batch can dispatch + // only when batch 1 releases its shards. + phase1 := makeExecs(2, 10) + phase2 := makeExecs(2, 10) + all := append(append([]*ExecutionInfo{}, phase1...), phase2...) 
+ env.RegisterActivityWithOptions(pageThrough(all, 20), activity.RegisterOptions{Name: "ListWorkflows"}) + + var ( + mu sync.Mutex + batches []*shardedBatchReq + secondStarted = make(chan struct{}) + secondOnce sync.Once + releaseObserved atomic.Bool + ) + env.RegisterActivityWithOptions(func(_ context.Context, req *shardedBatchReq) (replicateBatchResult, error) { + mu.Lock() + batches = append(batches, req) + idx := len(batches) + mu.Unlock() + + switch idx { + case 1: + // Signal-release, then block here until batch 2 actually + // dispatches. If signal-release wires through, the workflow + // dispatches batch 2 concurrently; if it doesn't, batch 2 + // can't run until this activity returns (the safety timeout + // below). + env.SignalWorkflow("ReleaseShards", releaseShardsPayload{ + BatchID: req.BatchID, + Shards: req.Executions.sortedShards(), + }) + select { + case <-secondStarted: + releaseObserved.Store(true) + case <-time.After(30 * time.Second): + // Safety release so the test fails the assertion rather + // than hanging indefinitely. Generous bound because CI + // can be slow; the happy path returns immediately. + } + case 2: + secondOnce.Do(func() { close(secondStarted) }) + default: + } + return replicateBatchResult{}, nil + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + ConcurrentBatchCount: 2, + }) + + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + require.Len(t, batches, 2, "expected two batches across the two pages") + require.True(t, releaseObserved.Load(), + "signal-release should let batch 2 dispatch while batch 1 still holds its dispatch slot") +} + +// TestSharded_ShardNoProgress_FailsWorkflow: activity returns a +// non-retryable ShardNoProgress error → workflow fails with that +// error, no CAN. 
+func TestSharded_ShardNoProgress_FailsWorkflow(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + registerShardedScaffolding(env, 2) + env.RegisterActivityWithOptions(pageThrough(makeExecs(2, 5), 1000), activity.RegisterOptions{Name: "ListWorkflows"}) + + env.RegisterActivityWithOptions(func(_ context.Context, req *shardedBatchReq) (replicateBatchResult, error) { + shards := req.Executions.sortedShards() + return replicateBatchResult{}, temporal.NewNonRetryableApplicationError( + "shard "+strconv.Itoa(int(shards[0]))+" stuck", "ShardNoProgress", nil) + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + }) + + require.True(t, env.IsWorkflowCompleted()) + err := env.GetWorkflowError() + require.Error(t, err) + var appErr *temporal.ApplicationError + require.ErrorAs(t, err, &appErr) + require.Equal(t, "ShardNoProgress", appErr.Type()) +} + +// TestSharded_DrainResult_FromActivityResult_FeedsCANCarryover: a +// non-empty InFlight in the activity's replicateBatchResult must +// populate drainPayload and end up as ResumeShards in the CAN +// carry-over. +// +// Tests the workflow plumbing only. In production, an activity +// returns InFlight after entering drain mode and grace-expiring; here +// we exercise the same code path by returning InFlight from a +// non-cancelled run, because the testsuite delivers cancellation as +// a CanceledError without preserving the activity's returned result. +// The dispatch coroutine's err == nil branch is what we're testing — +// it doesn't care whether the activity was cancelled or not, only +// whether the returned result has InFlight to fold into drainPayload. 
+func TestSharded_DrainResult_FromActivityResult_FeedsCANCarryover(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + registerShardedScaffolding(env, 10) + + // Page 1 returns a multi-shard population so the streaming packer + // has something to dispatch under the pinned BatchSize / + // MaxExecsPerShard the test sets below. The activity flips + // CAN-suggested from inside its body before returning, so by the + // time it has handed back its InFlight the workflow is committed + // to CAN — but without going through cancel, which the testsuite + // delivers as a CanceledError regardless of any result the + // activity returned. + pageExecs := makeExecs(10, 10) + // Drained exec mirrors a real input row so the simulated drain + // payload would be a valid response from a real activity. drainedBID + // is the first exec; drainedShard is the shard the workflow will + // hash it to. 
+ drainedBID := pageExecs[0].BusinessID + const drainedRunID = "run-drained" + drainedShard := common.WorkflowIDToHistoryShard(testNamespaceID, drainedBID, 10) + env.RegisterActivityWithOptions(func(_ context.Context, _ *workflowservice.ListWorkflowExecutionsRequest) (*listWorkflowsResponse, error) { + return &listWorkflowsResponse{ + Executions: pageExecs, + NextPageToken: []byte("more"), + }, nil + }, activity.RegisterOptions{Name: "ListWorkflows"}) + + env.RegisterActivityWithOptions(func(_ context.Context, _ *shardedBatchReq) (replicateBatchResult, error) { + env.SetContinueAsNewSuggested(true) + return replicateBatchResult{ + InFlight: []ResumeShard{{ + Shard: drainedShard, + Execs: map[string][]RunEntry{drainedBID: {{RunID: drainedRunID}}}, + NoProgressDuration: 42 * time.Second, + }}, + }, nil + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + BatchSize: 100, + MaxExecsPerShard: 10, + }) + + require.True(t, env.IsWorkflowCompleted()) + err := env.GetWorkflowError() + require.Error(t, err, "workflow should CAN, not return success") + + var canErr *workflow.ContinueAsNewError + require.ErrorAs(t, err, &canErr, "error should be ContinueAsNewError") + + var nextParams ShardedForceReplicationParams + require.NoError(t, converter.GetDefaultDataConverter().FromPayloads(canErr.Input, &nextParams)) + require.NotEmpty(t, nextParams.ResumeShards, "InFlight from a returned activity should appear in resume payload") + require.Equal(t, drainedShard, nextParams.ResumeShards[0].Shard) + runs, ok := nextParams.ResumeShards[0].Execs[drainedBID] + require.True(t, ok, "drained business ID should appear in nested resume payload") + require.Equal(t, []RunEntry{{RunID: drainedRunID}}, runs) + require.Equal(t, 42*time.Second, nextParams.ResumeShards[0].NoProgressDuration) +} + +// TestSharded_CancelBeforeStart_NoLostExecs pins 
down recovery when +// an activity is dispatched and the workflow CANs before the activity +// body runs: the activity returns CanceledError with no result, and +// the recovery path re-buckets the input execs into RecoveredBuckets +// so the next cycle dispatches them as fresh inject+verify batches. +func TestSharded_CancelBeforeStart_NoLostExecs(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + registerShardedScaffolding(env, 4) + + execs := makeExecs(4, 5) // 20 execs across 4 shards + pageServed := false + env.RegisterActivityWithOptions(func(_ context.Context, _ *workflowservice.ListWorkflowExecutionsRequest) (*listWorkflowsResponse, error) { + if pageServed { + return &listWorkflowsResponse{}, nil + } + pageServed = true + // Trigger CAN as soon as this page is served so drainForCAN + // runs before any dispatched batches can complete. + env.SetContinueAsNewSuggested(true) + return &listWorkflowsResponse{ + Executions: execs, + NextPageToken: []byte("more"), + }, nil + }, activity.RegisterOptions{Name: "ListWorkflows"}) + + // Activity that responds to ctx cancellation by returning a + // CanceledError with no result — simulating the + // "cancelled before any work done" path. 
+ var ( + batchCount atomic.Int32 + cancelledIDs []int64 + muIDs sync.Mutex + ) + env.RegisterActivityWithOptions(func(ctx context.Context, req *shardedBatchReq) (replicateBatchResult, error) { + batchCount.Add(1) + <-ctx.Done() + muIDs.Lock() + cancelledIDs = append(cancelledIDs, req.BatchID) + muIDs.Unlock() + return replicateBatchResult{}, temporal.NewCanceledError() + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + }) + + require.True(t, env.IsWorkflowCompleted()) + err := env.GetWorkflowError() + require.Error(t, err) + var canErr *workflow.ContinueAsNewError + require.ErrorAs(t, err, &canErr, "expected CAN, got %v", err) + + var nextParams ShardedForceReplicationParams + require.NoError(t, converter.GetDefaultDataConverter().FromPayloads(canErr.Input, &nextParams)) + + // Count execs the next cycle would re-dispatch. The activity + // body never ran in this race, so execs land in RecoveredBuckets + // (fresh inject+verify) rather than ResumeShards (verify-only). + recovered := nextParams.RecoveredBuckets.totalRuns() + + t.Logf("dispatched %d batches, cancelled %d, recovered execs %d (expected %d)", + batchCount.Load(), len(cancelledIDs), recovered, len(execs)) + + require.Equal(t, len(execs), recovered, "every dispatched exec should land in RecoveredBuckets when its activity is cancelled before it can run") + require.Empty(t, nextParams.ResumeShards, "no ResumeShards — activity never ran, never injected, so no resume work") +} + +// TestSharded_DisableVerification_NoVerifiedCount: with verification +// disabled the workflow runs inject-only batches, completes +// successfully, and the status query reports ReplicatedWorkflowCount=0 +// because no verification ran. 
+func TestSharded_DisableVerification_NoVerifiedCount(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + + execs := makeExecs(4, 5) + registerShardedScaffolding(env, 4) + env.RegisterActivityWithOptions(pageThrough(execs, 1000), activity.RegisterOptions{Name: "ListWorkflows"}) + + var sawDisable atomic.Bool + env.RegisterActivityWithOptions(func(_ context.Context, req *shardedBatchReq) (replicateBatchResult, error) { + if req.DisableVerification { + sawDisable.Store(true) + } + return replicateBatchResult{ + CompletedShards: req.Executions.sortedShards(), + VerifiedCount: 0, + }, nil + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + DisableVerification: true, + }) + + require.True(t, env.IsWorkflowCompleted(), "workflow should complete") + require.NoError(t, env.GetWorkflowError(), "workflow should succeed") + require.True(t, sawDisable.Load(), "activity req should carry DisableVerification=true") + + envValue, err := env.QueryWorkflow(forceReplicationStatusQueryType) + require.NoError(t, err) + var status ForceReplicationStatus + require.NoError(t, envValue.Get(&status)) + require.Equal(t, int64(0), status.ReplicatedWorkflowCount, "no verification ran so verified count must stay 0") +} + +// TestSharded_InvalidInput: validateShardedForceReplicationParams +// rejects an empty Namespace and a missing TargetClusterName. Mirrors +// the existing force-replication TestInvalidInput. 
+func TestSharded_InvalidInput(t *testing.T) { + for _, tc := range []struct { + name string + params ShardedForceReplicationParams + }{ + { + name: "empty namespace", + params: ShardedForceReplicationParams{}, + }, + { + name: "missing target cluster name", + params: ShardedForceReplicationParams{ + Namespace: "test-ns", + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, tc.params) + + require.True(t, env.IsWorkflowCompleted()) + err := env.GetWorkflowError() + require.Error(t, err) + require.Contains(t, err.Error(), "InvalidArgument") + }) + } +} + +// TestSharded_ListWorkflowsError: a hard failure from ListWorkflows +// propagates out as the workflow error. Mirrors the existing +// force-replication TestListWorkflowsError. +func TestSharded_ListWorkflowsError(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + registerShardedScaffolding(env, 2) + + env.RegisterActivityWithOptions(func(_ context.Context, _ *workflowservice.ListWorkflowExecutionsRequest) (*listWorkflowsResponse, error) { + return nil, temporal.NewNonRetryableApplicationError("mock listWorkflows error", "ListFailed", nil) + }, activity.RegisterOptions{Name: "ListWorkflows"}) + + // ReplicateBatch should never be invoked because listing fails up + // front. Register a fail-loud stub so we notice if the workflow + // ever dispatches a batch. 
+ env.RegisterActivityWithOptions(func(_ context.Context, _ *shardedBatchReq) (replicateBatchResult, error) { + t.Fatal("ReplicateBatch must not be called when listing fails") + return replicateBatchResult{}, nil + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + }) + + require.True(t, env.IsWorkflowCompleted()) + err := env.GetWorkflowError() + require.Error(t, err) + require.Contains(t, err.Error(), "mock listWorkflows error") +} + +// TestSharded_ReplicateBatchRetryableError: when ReplicateBatch returns +// a retryable error, the workflow exhausts its 3-attempt retry policy +// and surfaces the error as lastErr. Mirrors the existing +// TestGenerateReplicationTaskRetryableError. +func TestSharded_ReplicateBatchRetryableError(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + registerShardedScaffolding(env, 2) + env.RegisterActivityWithOptions(pageThrough(makeExecs(2, 5), 1000), activity.RegisterOptions{Name: "ListWorkflows"}) + + var attempts atomic.Int32 + env.RegisterActivityWithOptions(func(_ context.Context, _ *shardedBatchReq) (replicateBatchResult, error) { + attempts.Add(1) + return replicateBatchResult{}, temporal.NewApplicationError("transient backend error", "Transient") + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + }) + + require.True(t, env.IsWorkflowCompleted()) + err := env.GetWorkflowError() + require.Error(t, err) + require.Contains(t, err.Error(), "transient backend error") + // MaximumAttempts: 3 in spawnBatch's activity options — assert at + // least 2 retries actually happened so a future change that drops + // the retry 
policy fails this test. + require.GreaterOrEqual(t, attempts.Load(), int32(2), + "expected ReplicateBatch to be retried at least twice before failing") +} + +// TestSharded_TaskQueueReplicationFailure: when the +// SeedReplicationQueueWithUserDataEntries activity returns a +// non-retryable error, the child workflow signals the failure back +// and the parent fails with the seed error message; the status +// query reports the failure reason. Mirrors the existing +// TestTaskQueueReplicationFailure. +func TestSharded_TaskQueueReplicationFailure(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + registerShardedScaffoldingWithSeed(env, 2, + func(_ context.Context, _ TaskQueueUserDataReplicationParamsWithNamespace) error { + return temporal.NewNonRetryableApplicationError("namespace is required", "InvalidArgument", nil) + }) + env.RegisterActivityWithOptions(pageThrough(nil, 1000), activity.RegisterOptions{Name: "ListWorkflows"}) + env.RegisterActivityWithOptions(func(_ context.Context, _ *shardedBatchReq) (replicateBatchResult, error) { + return replicateBatchResult{}, nil + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + }) + + require.True(t, env.IsWorkflowCompleted()) + err := env.GetWorkflowError() + require.Error(t, err) + require.Contains(t, err.Error(), "namespace is required") + + envValue, qErr := env.QueryWorkflow(forceReplicationStatusQueryType) + require.NoError(t, qErr) + var status ForceReplicationStatus + require.NoError(t, envValue.Get(&status)) + require.True(t, status.TaskQueueUserDataReplicationStatus.Done) + require.Contains(t, status.TaskQueueUserDataReplicationStatus.FailureMessage, "namespace is required") +} + +// TestSharded_RecoveryBundle_OnBatchError: a batch returns a +// 
non-retryable error mid-cycle; the workflow latches lastErr, drains +// in-flight via cancellation, returns the error, and the status query +// reports a non-empty recovery bundle so an operator can start a +// fresh run with all unverified execs preserved. +func TestSharded_RecoveryBundle_OnBatchError(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + registerShardedScaffolding(env, 4) + + // One page of execs across 4 shards. + execs := makeExecs(4, 5) // 20 execs total + env.RegisterActivityWithOptions(pageThrough(execs, 1000), activity.RegisterOptions{Name: "ListWorkflows"}) + + // Every batch fails non-retryably so lastErr latches on the first + // return and subsequent in-flight batches are cancelled. + env.RegisterActivityWithOptions(func(_ context.Context, _ *shardedBatchReq) (replicateBatchResult, error) { + return replicateBatchResult{}, temporal.NewNonRetryableApplicationError( + "batch failed", "BatchFailed", nil) + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + }) + + require.True(t, env.IsWorkflowCompleted()) + err := env.GetWorkflowError() + require.Error(t, err) + var appErr *temporal.ApplicationError + require.ErrorAs(t, err, &appErr) + require.Equal(t, "BatchFailed", appErr.Type()) + + envValue, qErr := env.QueryWorkflow(forceReplicationStatusQueryType) + require.NoError(t, qErr) + var status ForceReplicationStatus + require.NoError(t, envValue.Get(&status)) + + // Recovery bundle: the cancelled in-flight batches go into + // RecoveryBuckets (collectRecoveredBuckets on batchExecs). The + // failed batch's execs land there too — its activity attempt + // errored after running, so batchExecs[id] is still populated. 
+ require.NotEmpty(t, status.RecoveryBuckets, + "failed run must expose its in-flight execs as RecoveryBuckets") + recovered := 0 + for _, byBID := range status.RecoveryBuckets { + for _, runs := range byBID { + recovered += len(runs) + } + } + require.Equal(t, len(execs), recovered, + "every listed exec should be recoverable via the bundle") +} + +// TestSharded_RecoveryBundle_PreservesUndispatchedResumeShards: when +// the first resume batch errors and latches lastErr, the remaining +// undispatched resume entries must be preserved in the recovery +// bundle rather than silently dropped. +func TestSharded_RecoveryBundle_PreservesUndispatchedResumeShards(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + registerShardedScaffolding(env, 8) + env.RegisterActivityWithOptions(pageThrough(nil, 1000), activity.RegisterOptions{Name: "ListWorkflows"}) + + // 8 shards × 60 execs each = 480 unverified. BatchSize=100 → + // dispatchResumeBatches plans ~5 batches; the first one fails, + // latches lastErr, and the remaining 4 must be preserved. 
+ resumeShards := make([]ResumeShard, 8) + for s := range 8 { + resumeShards[s] = ResumeShard{ + Shard: int32(s), + Execs: makeExecsForShard(int32(s), 60), + } + } + + env.RegisterActivityWithOptions(func(_ context.Context, _ *shardedBatchReq) (replicateBatchResult, error) { + return replicateBatchResult{}, temporal.NewNonRetryableApplicationError( + "resume batch failed", "ResumeFailed", nil) + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + ConcurrentBatchCount: 1, // serialize so we can observe the early-bail behaviour + ResumeShards: resumeShards, + }) + + require.True(t, env.IsWorkflowCompleted()) + require.Error(t, env.GetWorkflowError()) + + envValue, qErr := env.QueryWorkflow(forceReplicationStatusQueryType) + require.NoError(t, qErr) + var status ForceReplicationStatus + require.NoError(t, envValue.Get(&status)) + + // Every shard must show up exactly once across RecoveryResumeShards + // (still-undispatched, batched-but-cancelled, or the failed batch + // itself — all paths fold back into the recovery bundle). + recoveredRuns := 0 + for _, rs := range status.RecoveryResumeShards { + for _, runs := range rs.Execs { + recoveredRuns += len(runs) + } + } + for _, byBID := range status.RecoveryBuckets { + for _, runs := range byBID { + recoveredRuns += len(runs) + } + } + require.Equal(t, 8*60, recoveredRuns, + "all 480 resume execs should be recoverable; got %d", recoveredRuns) +} + +// TestSharded_RecoveryBundle_TracksCurrentPageToken: a List error +// after the first page should preserve the page token of the next +// page to read, not the start-of-run token. 
+func TestSharded_RecoveryBundle_TracksCurrentPageToken(t *testing.T) { + suite := &testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(ShardedForceReplicationWorkflow) + registerShardedScaffolding(env, 2) + + // First call succeeds; second call errors. Workflow processes page + // 1 successfully, then fails listing page 2 with NextPageToken + // pointing past page 1. + var listCalls atomic.Int32 + page1 := makeExecs(2, 3) + env.RegisterActivityWithOptions(func(_ context.Context, req *workflowservice.ListWorkflowExecutionsRequest) (*listWorkflowsResponse, error) { + n := listCalls.Add(1) + if n == 1 { + return &listWorkflowsResponse{ + Executions: page1, + NextPageToken: []byte("page-2"), + }, nil + } + return nil, temporal.NewNonRetryableApplicationError( + "list page 2 failed", "ListFailed", nil) + }, activity.RegisterOptions{Name: "ListWorkflows"}) + + // Batches succeed so page 1 doesn't pollute the recovery bundle — + // we want the page-token assertion isolated. + env.RegisterActivityWithOptions(func(_ context.Context, _ *shardedBatchReq) (replicateBatchResult, error) { + return replicateBatchResult{}, nil + }, activity.RegisterOptions{Name: "ReplicateBatch"}) + + env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{ + Namespace: "test-ns", + TargetClusterName: "remote_cluster", + }) + + require.True(t, env.IsWorkflowCompleted()) + require.Error(t, env.GetWorkflowError()) + + envValue, qErr := env.QueryWorkflow(forceReplicationStatusQueryType) + require.NoError(t, qErr) + var status ForceReplicationStatus + require.NoError(t, envValue.Get(&status)) + + require.Equal(t, []byte("page-2"), status.RecoveryNextPageToken, + "RecoveryNextPageToken should reflect where listing was about to resume, not start-of-run") +} + +// TestWrapBatchVerifyError_RoundTrip: the workflow must be able to +// recover the partial VerifiedCount that the activity encoded on a +// failed batch return. 
Covers retryability preservation, the
+// no-progress short-circuit, and inner-error reachability via
+// Unwrap (so consumers that care about the underlying Type can walk
+// past the wrapper).
+func TestWrapBatchVerifyError_RoundTrip(t *testing.T) {
+	t.Run("non-zero count survives wrap; inner reachable via Unwrap", func(t *testing.T) {
+		cause := temporal.NewNonRetryableApplicationError("stuck", "ShardNoProgress", nil)
+		wrapped := wrapBatchVerifyError(cause, 42)
+
+		// Outer wrapper carries the partial-verify tag and the count.
+		var appErr *temporal.ApplicationError
+		require.ErrorAs(t, wrapped, &appErr)
+		require.Equal(t, batchVerifyPartialErrorType, appErr.Type())
+		require.True(t, appErr.NonRetryable())
+		require.Equal(t, int64(42), extractVerifiedCountFromError(wrapped))
+
+		// Inner identity is preserved via Cause / Unwrap — consumers
+		// that key off the underlying type still get there.
+		var inner *temporal.ApplicationError
+		require.ErrorAs(t, appErr.Unwrap(), &inner)
+		require.Equal(t, "ShardNoProgress", inner.Type())
+	})
+
+	t.Run("zero count returns cause unchanged", func(t *testing.T) {
+		cause := temporal.NewApplicationError("trivial", "X")
+		require.Same(t, cause, wrapBatchVerifyError(cause, 0))
+		require.Equal(t, int64(0), extractVerifiedCountFromError(cause))
+	})
+
+	t.Run("non-ApplicationError cause is wrapped and remains reachable", func(t *testing.T) {
+		cause := errors.New("plain failure")
+		wrapped := wrapBatchVerifyError(cause, 7)
+		var appErr *temporal.ApplicationError
+		require.ErrorAs(t, wrapped, &appErr)
+		require.Equal(t, batchVerifyPartialErrorType, appErr.Type())
+		require.Equal(t, int64(7), extractVerifiedCountFromError(wrapped))
+		require.ErrorIs(t, wrapped, cause)
+	})
+
+	t.Run("unrelated ApplicationError returns 0", func(t *testing.T) {
+		// An ApplicationError that wasn't produced by wrapBatchVerifyError
+		// — extractVerifiedCountFromError must not pull garbage out of it.
+		other := temporal.NewApplicationError("plain", "Other")
+		require.Equal(t, int64(0), extractVerifiedCountFromError(other))
+	})
+}
+
+// TestSharded_PartialVerifiedCount_RecordedOnError: when a batch
+// errors out after partial verify, the wrapped error's count must be
+// folded into ReplicatedWorkflowCount so the status query reflects
+// work actually done rather than zeroing out a partially-successful
+// batch.
+func TestSharded_PartialVerifiedCount_RecordedOnError(t *testing.T) {
+	suite := &testsuite.WorkflowTestSuite{}
+	env := suite.NewTestWorkflowEnvironment()
+	env.RegisterWorkflow(ShardedForceReplicationWorkflow)
+	registerShardedScaffolding(env, 2)
+	env.RegisterActivityWithOptions(pageThrough(makeExecs(2, 5), 1000), activity.RegisterOptions{Name: "ListWorkflows"})
+
+	const partialCount int64 = 6
+	env.RegisterActivityWithOptions(func(_ context.Context, _ *shardedBatchReq) (replicateBatchResult, error) {
+		return replicateBatchResult{}, wrapBatchVerifyError(
+			temporal.NewNonRetryableApplicationError("simulated mid-verify failure", "Simulated", nil),
+			partialCount,
+		)
+	}, activity.RegisterOptions{Name: "ReplicateBatch"})
+
+	env.ExecuteWorkflow(ShardedForceReplicationWorkflow, ShardedForceReplicationParams{
+		Namespace:         "test-ns",
+		TargetClusterName: "remote_cluster",
+	})
+
+	require.True(t, env.IsWorkflowCompleted())
+	require.Error(t, env.GetWorkflowError())
+
+	envValue, qErr := env.QueryWorkflow(forceReplicationStatusQueryType)
+	require.NoError(t, qErr)
+	var status ForceReplicationStatus
+	require.NoError(t, envValue.Get(&status))
+
+	require.GreaterOrEqual(t, status.ReplicatedWorkflowCount, partialCount,
+		"failed batch's partial count should still be reflected in ReplicatedWorkflowCount")
+}
+
+// ---- internal helpers ----
+
+// makeExecsForShard builds `count` runs for the named shard's
+// ResumeShard.Execs payload. Each run gets its own businessID, so the map
+// holds `count` single-run entries (no BID-reuse). Run IDs are numbered
+// shard*1000+i, so they stay unique across shards only while count <= 1000.
+func makeExecsForShard(shard int32, count int) map[string][]RunEntry {
+	out := map[string][]RunEntry{}
+	for i := range count {
+		bid := "wf-" + strconv.Itoa(int(shard)) + "-" + strconv.Itoa(i)
+		out[bid] = []RunEntry{{RunID: "run-" + strconv.Itoa(int(shard)*1000+i)}}
+	}
+	return out
+}
diff --git a/tests/xdc/failover_test.go b/tests/xdc/failover_test.go
index 7795083a4e1..08e944d6322 100644
--- a/tests/xdc/failover_test.go
+++ b/tests/xdc/failover_test.go
@@ -2707,6 +2707,256 @@ func (s *FunctionalClustersTestSuite) TestForceMigration_ResetWorkflow() {
 	verifyHistory(workflowID, resp.GetRunId())
 }
 
+// TestForceMigration_ClosedWorkflow_Sharded is the sharded-variant
+// duplicate of TestForceMigration_ClosedWorkflow. Kept intentionally
+// close to the original — same workflow IDs prefixed with "sharded-",
+// same assertions — so a future swap from legacy to sharded
+// force-replication can drop the legacy version and keep behaviour
+// coverage intact.
+func (s *FunctionalClustersTestSuite) TestForceMigration_ClosedWorkflow_Sharded() {
+	testCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	defer cancel()
+
+	namespace := s.createNamespaceInCluster0(true)
+
+	taskqueue := "functional-local-force-replication-sharded-task-queue"
+	client0, worker0 := s.newClientAndWorker(s.clusters[0].Host().FrontendGRPCAddress(), namespace, taskqueue, "worker0")
+
+	testWorkflowFn := func(ctx workflow.Context) error {
+		return nil
+	}
+
+	worker0.RegisterWorkflow(testWorkflowFn)
+	s.NoError(worker0.Start())
+	defer worker0.Stop()
+
+	// Start wf1
+	workflowID := "sharded-force-replication-test-wf-1"
+	run1, err := client0.ExecuteWorkflow(testCtx, sdkclient.StartWorkflowOptions{
+		ID:                 workflowID,
+		TaskQueue:          taskqueue,
+		WorkflowRunTimeout: time.Second * 30,
+	}, testWorkflowFn)
+
+	s.NoError(err)
+	s.NotEmpty(run1.GetRunID())
+	s.logger.Info("start wf1", tag.WorkflowRunID(run1.GetRunID()))
+	// wait until wf1 complete
+	err = run1.Get(testCtx, nil)
+	s.NoError(err)
+
+	// Update ns to have 2 clusters
+	s.updateNamespaceClusters(namespace, 0, s.clusters)
+
+	// Wait for wf1 to be indexed before force-replication.
+	s.waitForVisibilityCount(testCtx, namespace, 1)
+
+	// Start force-replicate wf — sharded variant lives on its own
+	// dedicated task queue (MigrationShardedActivityTQ) and is
+	// registered under workflow name "force-replication-sharded".
+	sysClient, err := sdkclient.Dial(sdkclient.Options{
+		HostPort:  s.clusters[0].Host().FrontendGRPCAddress(),
+		Namespace: primitives.SystemLocalNamespace,
+	})
+	s.NoError(err)
+	defer sysClient.Close() // release the connection opened by Dial when the test ends
+	forceReplicationWorkflowID := "sharded-force-replication-wf"
+	sysWfRun, err := sysClient.ExecuteWorkflow(testCtx, sdkclient.StartWorkflowOptions{
+		ID:                 forceReplicationWorkflowID,
+		TaskQueue:          primitives.MigrationShardedActivityTQ,
+		WorkflowRunTimeout: time.Second * 30,
+	}, "force-replication-sharded", migration.ShardedForceReplicationParams{
+		Namespace:         namespace,
+		TargetClusterName: s.clusters[1].ClusterName(),
+	})
+	s.NoError(err)
+	s.NoError(sysWfRun.Get(testCtx, nil))
+
+	// Verify every wf in the ns is now available in cluster2 (clusters[1])
+	client1, worker1 := s.newClientAndWorker(s.clusters[1].Host().FrontendGRPCAddress(), namespace, taskqueue, "worker1")
+	verify := func(wfID string, expectedRunID string) {
+		await.RequireTruef(s.T(), func() bool {
+			desc1, err := client1.DescribeWorkflowExecution(testCtx, wfID, "")
+			if err != nil {
+				return false
+			}
+			return desc1.WorkflowExecutionInfo.Execution.RunId == expectedRunID &&
+				desc1.WorkflowExecutionInfo.Status == enumspb.WORKFLOW_EXECUTION_STATUS_COMPLETED
+		}, 15*time.Second, 200*time.Millisecond, "workflow %s should be replicated to cluster2", wfID)
+	}
+	verify(workflowID, run1.GetRunID())
+
+	s.failover(namespace, 0, s.clusters[1].ClusterName(), 2)
+
+	worker1.RegisterWorkflow(testWorkflowFn)
+	s.NoError(worker1.Start())
+	defer worker1.Stop()
+
+	// Reset the replicated workflow on cluster2 (clusters[1]) after failing over to it
+	resetResp, err := client1.ResetWorkflowExecution(testCtx, &workflowservice.ResetWorkflowExecutionRequest{
+		Namespace: namespace,
+		WorkflowExecution: &commonpb.WorkflowExecution{
+			WorkflowId: workflowID,
+			RunId:      run1.GetRunID(),
+		},
+		Reason:                    "force-replication-sharded-test",
+		WorkflowTaskFinishEventId: 3,
+		RequestId:                 uuid.NewString(),
+	})
+	s.NoError(err)
+
+	resetRun := client1.GetWorkflow(testCtx, workflowID, resetResp.GetRunId())
+	err = resetRun.Get(testCtx, nil)
+	s.NoError(err)
+
+	await.RequireTruef(s.T(), func() bool {
+		descResp, err := client1.DescribeWorkflowExecution(testCtx, workflowID, resetResp.GetRunId())
+		if err != nil {
+			return false
+		}
+		return descResp.GetWorkflowExecutionInfo().Status == enumspb.WORKFLOW_EXECUTION_STATUS_COMPLETED
+	}, 15*time.Second, 200*time.Millisecond, "reset workflow should be visible on cluster2")
+}
+
+// TestForceMigration_ResetWorkflow_Sharded is the sharded-variant
+// duplicate of TestForceMigration_ResetWorkflow. Same intent as the
+// legacy version: replicate a (reset → completed) workflow pair across
+// clusters and confirm both runs are visible on the target. Asserts
+// the activity-level verification count by walking the workflow's
+// history for "ReplicateBatch" activity completions (the sharded
+// activity name, replacing legacy "VerifyReplicationTasks").
+func (s *FunctionalClustersTestSuite) TestForceMigration_ResetWorkflow_Sharded() {
+	testCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	defer cancel()
+
+	namespace := s.createNamespaceInCluster0(true)
+
+	taskqueue := "functional-force-replication-sharded-reset-task-queue"
+	client0, worker0 := s.newClientAndWorker(s.clusters[0].Host().FrontendGRPCAddress(), namespace, taskqueue, "worker0")
+
+	testWorkflowFn := func(ctx workflow.Context) error {
+		return nil
+	}
+
+	worker0.RegisterWorkflow(testWorkflowFn)
+	s.NoError(worker0.Start())
+	defer worker0.Stop()
+
+	// Start wf1
+	workflowID := "sharded-force-replication-test-reset-wf-1"
+	run1, err := client0.ExecuteWorkflow(testCtx, sdkclient.StartWorkflowOptions{
+		ID:                 workflowID,
+		TaskQueue:          taskqueue,
+		WorkflowRunTimeout: time.Second * 30,
+	}, testWorkflowFn)
+
+	s.NoError(err)
+	s.NotEmpty(run1.GetRunID())
+	s.logger.Info("start wf1", tag.WorkflowRunID(run1.GetRunID()))
+	// wait until wf1 complete
+	err = run1.Get(testCtx, nil)
+	s.NoError(err)
+
+	resp, err := client0.ResetWorkflowExecution(testCtx, &workflowservice.ResetWorkflowExecutionRequest{
+		Namespace: namespace,
+		WorkflowExecution: &commonpb.WorkflowExecution{
+			WorkflowId: workflowID,
+			RunId:      run1.GetRunID(),
+		},
+		Reason:                    "test",
+		WorkflowTaskFinishEventId: 3,
+		RequestId:                 uuid.NewString(),
+	})
+	s.NoError(err)
+	resetRun := client0.GetWorkflow(testCtx, workflowID, resp.GetRunId())
+	err = resetRun.Get(testCtx, nil)
+	s.NoError(err)
+
+	// Update ns to have 2 clusters
+	s.updateNamespaceClusters(namespace, 0, s.clusters)
+
+	// Wait for both workflow runs (original + reset) to be indexed before force-replication.
+	s.waitForVisibilityCount(testCtx, namespace, 2)
+
+	// Start force-replicate wf
+	sysClient, err := sdkclient.Dial(sdkclient.Options{
+		HostPort:  s.clusters[0].Host().FrontendGRPCAddress(),
+		Namespace: primitives.SystemLocalNamespace,
+	})
+	s.NoError(err)
+	defer sysClient.Close() // release the connection opened by Dial when the test ends
+	forceReplicationWorkflowID := "sharded-force-replication-wf"
+	sysWfRun, err := sysClient.ExecuteWorkflow(testCtx, sdkclient.StartWorkflowOptions{
+		ID:                 forceReplicationWorkflowID,
+		TaskQueue:          primitives.MigrationShardedActivityTQ,
+		WorkflowRunTimeout: time.Second * 30,
+	}, "force-replication-sharded", migration.ShardedForceReplicationParams{
+		Namespace:         namespace,
+		TargetClusterName: s.clusters[1].ClusterName(),
+	})
+	s.NoError(err)
+	s.NoError(sysWfRun.Get(testCtx, nil))
+
+	// Verify the force-replication workflow actually ran ReplicateBatch
+	// activities (the sharded activity name; legacy is
+	// VerifyReplicationTasks) and that VerifiedCount sums to the
+	// expected number of workflow runs.
+	var totalVerifiedCount int64
+	scheduledActivityTypes := make(map[int64]string) // scheduledEventId -> activity type name
+	histIter := sysClient.GetWorkflowHistory(testCtx, forceReplicationWorkflowID, sysWfRun.GetRunID(),
+		false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT)
+	for histIter.HasNext() {
+		event, err := histIter.Next()
+		s.NoError(err)
+		switch event.GetEventType() {
+		case enumspb.EVENT_TYPE_ACTIVITY_TASK_SCHEDULED:
+			attrs := event.GetActivityTaskScheduledEventAttributes()
+			scheduledActivityTypes[event.GetEventId()] = attrs.GetActivityType().GetName()
+		case enumspb.EVENT_TYPE_ACTIVITY_TASK_COMPLETED:
+			attrs := event.GetActivityTaskCompletedEventAttributes()
+			activityType := scheduledActivityTypes[attrs.GetScheduledEventId()]
+			if activityType != "ReplicateBatch" {
+				continue
+			}
+			result := attrs.GetResult()
+			if result != nil && len(result.GetPayloads()) > 0 {
+				// Mirrors replicateBatchResult.VerifiedCount on the
+				// activity-side struct. Anonymous shape avoids
+				// importing the activity package's internal type.
+				var batchResult struct {
+					VerifiedCount int64
+				}
+				s.NoError(payloads.Decode(result, &batchResult))
+				totalVerifiedCount += batchResult.VerifiedCount
+			}
+		default:
+		}
+	}
+	// Expect exactly 2 verified workflow runs: original run + reset run
+	s.Equal(int64(2), totalVerifiedCount,
+		"sharded force-replication should have verified exactly 2 workflow runs (original + reset run)")
+
+	s.waitForClusterSynced()
+
+	// Verify every wf in the ns is now available in cluster2 (clusters[1])
+	client1, _ := s.newClientAndWorker(s.clusters[1].Host().FrontendGRPCAddress(), namespace, taskqueue, "worker1")
+	verifyHistory := func(wfID string, runID string) {
+		iter1 := client0.GetWorkflowHistory(testCtx, wfID, runID, false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT)
+		iter2 := client1.GetWorkflowHistory(testCtx, wfID, runID, false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT)
+		for iter1.HasNext() && iter2.HasNext() {
+			event1, err := iter1.Next()
+			s.NoError(err)
+			event2, err := iter2.Next()
+			s.NoError(err)
+			s.Equal(event1, event2)
+		}
+		s.False(iter1.HasNext())
+		s.False(iter2.HasNext())
+	}
+	verifyHistory(workflowID, run1.GetRunID())
+	verifyHistory(workflowID, resp.GetRunId())
+}
+
 func (s *FunctionalClustersTestSuite) TestBlockNamespaceDeleteInPassiveCluster() {
 	namespace := s.createGlobalNamespace()
 
diff --git a/tests/xdc/user_data_replication_test.go b/tests/xdc/user_data_replication_test.go
index b49f6be6698..a8b18bc915f 100644
--- a/tests/xdc/user_data_replication_test.go
+++ b/tests/xdc/user_data_replication_test.go
@@ -499,6 +499,152 @@ func (s *UserDataReplicationTestSuite) TestUserDataEntriesAreReplicatedOnDemand(
 	}
 }
 
+// TestUserDataEntriesAreReplicatedOnDemand_Sharded is the
+// sharded-variant duplicate of TestUserDataEntriesAreReplicatedOnDemand.
+// Same intent — confirm that running the force-replication workflow
+// pushes every task-queue user data entry onto the namespace
+// replication queue — but exercised through the sharded workflow
+// (registered name "force-replication-sharded" on
+// MigrationShardedActivityTQ). DisableVerification:true mirrors the
+// legacy test's EnableVerification:false default; TargetClusterName is
+// still set (validation requires it, and the workflow fetches the
+// remote shard count even when there are no execs to replicate) but
+// the inject path is never reached here since this test only exercises
+// the task-queue-user-data side of force-replication.
+func (s *UserDataReplicationTestSuite) TestUserDataEntriesAreReplicatedOnDemand_Sharded() {
+	ctx := testcore.NewContext()
+	activeFrontendClient := s.clusters[0].FrontendClient()
+	adminClient := s.clusters[0].AdminClient()
+	numTaskQueues := 10
+
+	replicationResponse, err := adminClient.GetNamespaceReplicationMessages(ctx, &adminservice.GetNamespaceReplicationMessagesRequest{
+		ClusterName:            "follower",
+		LastRetrievedMessageId: -1,
+		LastProcessedMessageId: -1,
+	})
+	s.NoError(err)
+	lastMessageID := replicationResponse.GetMessages().GetLastRetrievedMessageId()
+
+	namespace := s.createNamespaceInCluster0(true)
+	description, err := activeFrontendClient.DescribeNamespace(testcore.NewContext(), &workflowservice.DescribeNamespaceRequest{Namespace: namespace})
+	s.NoError(err)
+
+	expectedReplicatedTaskQueues := make(map[string]struct{}, 2*numTaskQueues) // two queues (v1q*, v2q*) per iteration
+	for i := range numTaskQueues {
+		taskQueue := fmt.Sprintf("v1q%v", i)
+		res, err := activeFrontendClient.UpdateWorkerBuildIdCompatibility(ctx, &workflowservice.UpdateWorkerBuildIdCompatibilityRequest{
+			Namespace: namespace,
+			TaskQueue: taskQueue,
+			Operation: &workflowservice.UpdateWorkerBuildIdCompatibilityRequest_AddNewBuildIdInNewDefaultSet{
+				AddNewBuildIdInNewDefaultSet: "v0.1",
+			},
+		})
+		s.NoError(err)
+		s.NotNil(res)
+		expectedReplicatedTaskQueues[taskQueue] = struct{}{}
+
+		taskQueue2 := fmt.Sprintf("v2q%v", i)
+		rules, err := activeFrontendClient.GetWorkerVersioningRules(ctx, &workflowservice.GetWorkerVersioningRulesRequest{
+			Namespace: namespace,
+			TaskQueue: taskQueue2,
+		})
+		s.NoError(err)
+		s.NotNil(rules)
+
+		rulesRes, err := activeFrontendClient.UpdateWorkerVersioningRules(ctx, &workflowservice.UpdateWorkerVersioningRulesRequest{
+			Namespace:     namespace,
+			TaskQueue:     taskQueue2,
+			ConflictToken: rules.ConflictToken,
+			Operation: &workflowservice.UpdateWorkerVersioningRulesRequest_InsertAssignmentRule{
+				InsertAssignmentRule: &workflowservice.UpdateWorkerVersioningRulesRequest_InsertBuildIdAssignmentRule{
+					Rule: &taskqueuepb.BuildIdAssignmentRule{
+						TargetBuildId: "asdf",
+					},
+				},
+			},
+		})
+		s.NoError(err)
+		s.NotNil(rulesRes)
+		expectedReplicatedTaskQueues[taskQueue2] = struct{}{}
+	}
+
+	// update namespace to cross clusters
+	s.updateNamespaceClusters(namespace, 0, s.clusters)
+
+	// we should see one new namespace task in the replication queue
+	replicationResponse, err = adminClient.GetNamespaceReplicationMessages(ctx, &adminservice.GetNamespaceReplicationMessagesRequest{
+		ClusterName:            "follower",
+		LastRetrievedMessageId: lastMessageID,
+		LastProcessedMessageId: -1,
+	})
+	s.NoError(err)
+	lastMessageID = replicationResponse.GetMessages().GetLastRetrievedMessageId()
+	s.Len(replicationResponse.GetMessages().ReplicationTasks, 1)
+	task := replicationResponse.GetMessages().ReplicationTasks[0]
+	s.Equal(namespace, task.GetNamespaceTaskAttributes().GetInfo().GetName())
+
+	// start sharded force-replicate wf
+	sysClient, err := sdkclient.Dial(sdkclient.Options{
+		HostPort:  s.clusters[0].Host().FrontendGRPCAddress(),
+		Namespace: primitives.SystemLocalNamespace,
+	})
+	s.NoError(err)
+	defer sysClient.Close() // release the connection opened by Dial when the test ends
+	run, err := sysClient.ExecuteWorkflow(ctx, sdkclient.StartWorkflowOptions{
+		ID:                 "sharded-force-replication-wf",
+		TaskQueue:          primitives.MigrationShardedActivityTQ,
+		WorkflowRunTimeout: time.Second * 30,
+	}, "force-replication-sharded", migration.ShardedForceReplicationParams{
+		Namespace:           namespace,
+		TargetClusterName:   s.clusters[1].ClusterName(),
+		DisableVerification: true,
+	})
+	s.NoError(err)
+	s.NoError(run.Get(ctx, nil))
+
+	replicationResponse, err = adminClient.GetNamespaceReplicationMessages(ctx, &adminservice.GetNamespaceReplicationMessagesRequest{
+		ClusterName:            "follower",
+		LastRetrievedMessageId: lastMessageID,
+		LastProcessedMessageId: -1,
+	})
+	s.NoError(err)
+
+	// we should see a user data task for all task queues
+	seenTaskQueues := make(map[string]struct{}, 2*numTaskQueues) // same size as the expected set
+	for _, task := range replicationResponse.GetMessages().ReplicationTasks {
+		if attrs := task.GetTaskQueueUserDataAttributes(); attrs.GetNamespaceId() == description.GetNamespaceInfo().Id {
+			seenTaskQueues[attrs.GetTaskQueueName()] = struct{}{}
+		}
+	}
+	s.Equal(expectedReplicatedTaskQueues, seenTaskQueues)
+
+	// failover and check on the other side
+	s.failover(namespace, 0, s.clusters[1].ClusterName(), 2)
+
+	activeFrontendClient = s.clusters[1].FrontendClient()
+	for i := range numTaskQueues {
+		taskQueue := fmt.Sprintf("v1q%v", i)
+
+		get, err := activeFrontendClient.GetWorkerBuildIdCompatibility(ctx, &workflowservice.GetWorkerBuildIdCompatibilityRequest{
+			Namespace: namespace,
+			TaskQueue: taskQueue,
+		})
+		s.NoError(err)
+		s.NotNil(get)
+
+		s.NotEmpty(get.MajorVersionSets)
+
+		taskQueue2 := fmt.Sprintf("v2q%v", i)
+		rules, err := activeFrontendClient.GetWorkerVersioningRules(ctx, &workflowservice.GetWorkerVersioningRulesRequest{
+			Namespace: namespace,
+			TaskQueue: taskQueue2,
+		})
+		s.NoError(err)
+		s.NotNil(rules)
+		s.NotEmpty(rules.AssignmentRules)
+	}
+}
+
 func (s *UserDataReplicationTestSuite) TestUserDataTombstonesAreReplicated() {
 	s.T().SkipNow() // flaky test
 	ctx := testcore.NewContext()