diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json index e623d3373a93..50d17c108f2e 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run!", - "modification": 1, + "modification": 2, } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json index e623d3373a93..50d17c108f2e 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run!", - "modification": 1, + "modification": 2, } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowExecutionContext.java index 6ff05b4b4452..888e954c1c9f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowExecutionContext.java @@ -150,6 +150,10 @@ boolean isSinkFullHintSet() { // the state size might grow unbounded. } + protected final long getBytesSinked() { + return bytesSinked; + } + /** * Sets a flag to indicate that a sink has enough data written to it. This hint is read by * upstream producers to stop producing if they can. Mainly used in streaming. diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 9e82343474c6..d03167540a88 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -181,7 +181,7 @@ public final class StreamingDataflowWorker { "windmill_bounded_queue_executor_use_fair_monitor"; // Don't use. Experiment guarding multi key bundles. The feature is work in progress and // incomplete. - private static final String UNSTABLE_ENABLE_MULTI_KEY_BUNDLE = "unstable_enable_multi_key_bundle"; + public static final String UNSTABLE_ENABLE_MULTI_KEY_BUNDLE = "unstable_enable_multi_key_bundle"; private final WindmillStateCache stateCache; private AtomicReference statusPages = new AtomicReference<>(); @@ -257,6 +257,7 @@ private StreamingDataflowWorker( this.streamingWorkScheduler = StreamingWorkScheduler.create( options, + DataflowRunner.hasExperiment(options, UNSTABLE_ENABLE_MULTI_KEY_BUNDLE), clock, readerCache, mapTaskExecutorFactory, @@ -1198,9 +1199,14 @@ private void onCompleteCommit(CompleteCommit completeCommit) { computationStateCache .getIfPresent(completeCommit.computationId()) .ifPresent( - state -> + state -> { + if (completeCommit.retryableFailure()) { + state.reExecuteActiveWork(completeCommit.shardedKey(), completeCommit.workId()); + } else { state.completeWorkAndScheduleNextWorkForKey( - completeCommit.shardedKey(), completeCommit.workId())); + completeCommit.shardedKey(), completeCommit.workId()); + } + }); } @AutoValue diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 00fdf67b8d02..c40eed196a10 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -18,12 +18,14 @@ package org.apache.beam.runners.dataflow.worker; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import com.google.api.services.dataflow.model.CounterUpdate; import com.google.api.services.dataflow.model.SideInputInfo; import java.io.Closeable; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; @@ -33,6 +35,8 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; @@ -45,10 +49,11 @@ import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.core.metrics.ExecutionStateTracker.ExecutionState; import org.apache.beam.runners.dataflow.worker.DataflowOperationContext.DataflowExecutionState; -import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.StepContext; import org.apache.beam.runners.dataflow.worker.counters.CounterFactory; import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; @@ -56,6 +61,10 @@ import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInput; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; +import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId; @@ -75,6 +84,8 @@ import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -109,9 +120,16 @@ @SuppressWarnings({"deprecation"}) @NotThreadSafe @Internal -public class StreamingModeExecutionContext extends DataflowExecutionContext { +public class StreamingModeExecutionContext + extends DataflowExecutionContext { private static final Logger LOG = LoggerFactory.getLogger(StreamingModeExecutionContext.class); + private static final String WINDMILL_MAX_KEY_GROUP_BATCH_SIZE = + "windmill_max_key_group_batch_size"; + private static final String WINDMILL_MAX_KEY_GROUP_BATCH_TIME_MS = + "windmill_max_key_group_batch_time_ms"; + private static final String WINDMILL_MAX_KEY_GROUP_BATCH_SINK_BYTES = + "windmill_max_key_group_batch_sink_bytes"; private final String computationId; private final ImmutableMap stateNameMap; @@ -141,6 +159,7 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext keyCoder; + + // Key switch listener to delegate MDC logging context and thread name updates + public interface KeyTransitionListener { + void onKeyTransition(@Nullable Work oldWork, Work newWork); + } + + @SuppressWarnings("UnusedVariable") + private @Nullable KeyTransitionListener keyTransitionListener; + + private List executedWorks = new ArrayList<>(); + private List outputBuilders = new ArrayList<>(); + + // Map> + private Map> accumulatedCallbacks = new HashMap<>(); + private final AtomicBoolean workBatchFailed = new AtomicBoolean(false); + private @Nullable WindmillStateReader activeStateReader; + private long stateBytesRead = 0; + private final String sourceBytesProcessCounterName; + + private final int maxKeyGroupBatchSize; + private final long maxKeyGroupBatchTimeNanos; + private final boolean multiKeyBundleEnabled; + private final long maxKeyGroupBatchSinkBytes; + private int workItemsPolled = 0; + private long bundleStartTimeNanos = 0; + public StreamingModeExecutionContext( CounterFactory counterFactory, String computationId, @@ -170,7 +225,13 @@ public StreamingModeExecutionContext( StreamingModeExecutionStateRegistry executionStateRegistry, StreamingGlobalConfigHandle globalConfigHandle, long sinkByteLimit, - boolean throwExceptionOnLargeOutput) { + boolean throwExceptionOnLargeOutput, + HotKeyLogger hotKeyLogger, + boolean hotKeyLoggingEnabled, + String stepName, + String sourceBytesProcessCounterName, + PipelineOptions options, + SideInputStateFetcherFactory sideInputStateFetcherFactory) { super( counterFactory, metricsContainerRegistry, @@ -185,6 +246,37 @@ public StreamingModeExecutionContext( this.stateCache = stateCache; this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; this.throwExceptionOnLargeOutput = throwExceptionOnLargeOutput; + this.hotKeyLogger = checkNotNull(hotKeyLogger); + this.hotKeyLoggingEnabled = hotKeyLoggingEnabled; + this.stepName = checkNotNull(stepName); + this.sourceBytesProcessCounterName = checkNotNull(sourceBytesProcessCounterName); + this.sideInputStateFetcherFactory = checkNotNull(sideInputStateFetcherFactory); + + // Initialize batch limits from pipeline options + this.maxKeyGroupBatchSize = + tryParseInt( + ExperimentalOptions.getExperimentValue(options, WINDMILL_MAX_KEY_GROUP_BATCH_SIZE), + 100, + WINDMILL_MAX_KEY_GROUP_BATCH_SIZE); + + long batchTimeMs = + tryParseLong( + ExperimentalOptions.getExperimentValue(options, WINDMILL_MAX_KEY_GROUP_BATCH_TIME_MS), + 100, + WINDMILL_MAX_KEY_GROUP_BATCH_TIME_MS); + this.maxKeyGroupBatchTimeNanos = TimeUnit.MILLISECONDS.toNanos(batchTimeMs); + + this.multiKeyBundleEnabled = + ExperimentalOptions.hasExperiment( + options, StreamingDataflowWorker.UNSTABLE_ENABLE_MULTI_KEY_BUNDLE); + + this.maxKeyGroupBatchSinkBytes = + tryParseLong( + ExperimentalOptions.getExperimentValue( + options, WINDMILL_MAX_KEY_GROUP_BATCH_SINK_BYTES), + StreamingDataflowWorker.MAX_SINK_BYTES, + WINDMILL_MAX_KEY_GROUP_BATCH_SINK_BYTES); + StreamingGlobalConfig config = globalConfigHandle.getConfig(); this.operationalLimits = config.operationalLimits(); this.windmillTagEncoding = @@ -193,6 +285,41 @@ public StreamingModeExecutionContext( : WindmillTagEncodingV1.instance(); } + private static int tryParseInt(@Nullable String value, int defaultValue, String experimentName) { + if (value == null) { + return defaultValue; + } + try { + return Integer.parseInt(value); + } catch (NumberFormatException e) { + LOG.warn( + "Failed to parse experiment {} value '{}' as integer, falling back to default: {}", + experimentName, + value, + defaultValue, + e); + return defaultValue; + } + } + + private static long tryParseLong( + @Nullable String value, long defaultValue, String experimentName) { + if (value == null) { + return defaultValue; + } + try { + return Long.parseLong(value); + } catch (NumberFormatException e) { + LOG.warn( + "Failed to parse experiment {} value '{}' as long, falling back to default: {}", + experimentName, + value, + defaultValue, + e); + return defaultValue; + } + } + @VisibleForTesting public final long getBacklogBytes() { return backlogBytes; @@ -211,7 +338,7 @@ public boolean throwExceptionsForLargeOutput() { } public boolean workIsFailed() { - return work != null && work.isFailed(); + return workBatchFailed.get(); } public boolean getDrainMode() { @@ -243,43 +370,113 @@ public byte[] getCurrentRecordOffset() { return checkStateNotNull(activeReader).getCurrentRecordOffset(); } + public void clear() { + for (Work w : executedWorks) { + w.setOnFailureListener(null); + } + this.executedWorks = new ArrayList<>(); + this.outputBuilders = new ArrayList<>(); + this.accumulatedCallbacks = new HashMap<>(); + this.workBatchFailed.set(false); + this.sideInputCache.clear(); + this.activeStateReader = null; + this.activeReader = null; + this.keyCoder = null; + this.workExecutor = null; + this.workQueueExecutor = null; + this.budgetHandle = null; + this.keyTransitionListener = null; + this.work = null; + this.key = null; + this.outputBuilder = null; + this.sideInputStateFetcher = null; + this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; + clearSinkFullHint(); + this.stateBytesRead = 0; + } + public void start( - @Nullable Object key, Work work, - WindmillStateReader stateReader, - SideInputStateFetcher sideInputStateFetcher, - Windmill.WorkItemCommitRequest.Builder outputBuilder, - WorkExecutor workExecutor) { - this.key = key; - this.work = work; + WorkExecutor workExecutor, + BoundedQueueExecutor workQueueExecutor, + BoundedQueueExecutorWorkHandle budgetHandle, + @Nullable Coder keyCoder, + KeyTransitionListener keyTransitionListener) { + clear(); + this.keyCoder = keyCoder; this.workExecutor = workExecutor; - this.finishKeyCalled = false; - this.computationKey = WindmillComputationKey.create(computationId, work.getShardedKey()); - this.sideInputStateFetcher = sideInputStateFetcher; + this.workQueueExecutor = workQueueExecutor; + this.budgetHandle = budgetHandle; + this.keyTransitionListener = keyTransitionListener; + + this.workItemsPolled = 1; + this.bundleStartTimeNanos = System.nanoTime(); + StreamingGlobalConfig config = globalConfigHandle.getConfig(); // Snapshot the limits for entire bundle processing. this.operationalLimits = config.operationalLimits(); - this.outputBuilder = outputBuilder; - this.sideInputCache.clear(); - this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; - clearSinkFullHint(); - Instant processingTime = computeProcessingTime(work.getWorkItem().getTimers().getTimersList()); + startForNewKey(work); + } - Collection stepContexts = getAllStepContexts(); - if (!stepContexts.isEmpty()) { - // This must be only created once for the workItem as token validation will fail if the same - // work token is reused. - WindmillStateCache.ForKey cacheForKey = - stateCache.forKey(getComputationKey(), getWorkItem().getCacheToken(), getWorkToken()); - for (StepContext stepContext : stepContexts) { - stepContext.start(stateReader, processingTime, cacheForKey, work.watermarks()); + private @Nullable Object decodeKey(Work work) { + // If the read output KVs, then we can decode Windmill's byte key into userland + // key object and provide it to the execution context for use with per-key state. + // Otherwise, we pass null. + // + // The coder type that will be present is: + // WindowedValueCoder(TimerOrElementCoder(KvCoder)) + if (keyCoder != null) { + try { + return keyCoder.decode(work.getWorkItem().getKey().newInput(), Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Failed to decode key during processing", e); + } + } + return null; + } + + private Windmill.WorkItemCommitRequest.Builder createOutputBuilder(Work work) { + return Windmill.WorkItemCommitRequest.newBuilder() + .setKey(work.getWorkItem().getKey()) + .setShardingKey(work.getWorkItem().getShardingKey()) + .setWorkToken(work.getWorkItem().getWorkToken()) + .setCacheToken(work.getWorkItem().getCacheToken()); + } + + private void logHotKeyIfDetected(Work work, @Nullable Object decodedKey) { + if (work.getWorkItem().hasHotKeyInfo()) { + Windmill.HotKeyInfo hotKeyInfo = work.getWorkItem().getHotKeyInfo(); + Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 1000); + if (decodedKey != null && hotKeyLoggingEnabled) { + hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, decodedKey); + } else { + hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge); } } } + private void startStepContexts( + WindmillStateReader stateReader, + Instant processingTime, + WindmillStateCache.ForKey cacheForKey, + Watermarks watermarks) { + Collection stepContexts = getAllStepContexts(); + for (StepContext stepContext : stepContexts) { + stepContext.start(stateReader, processingTime, cacheForKey, watermarks); + } + } + public void finishKey() { - checkState(!finishKeyCalled, "finishKey was already called"); + if (finishKeyCalled) { + return; + } + if (activeStateReader != null) { + this.stateBytesRead += activeStateReader.getBytesRead(); + } + if (sideInputStateFetcher != null) { + this.stateBytesRead += sideInputStateFetcher.getBytesRead(); + } checkStateNotNull(workExecutor, "workExecutor must be set before calling finishKey()"); try { workExecutor.finishKey(); @@ -287,6 +484,8 @@ public void finishKey() { throw new RuntimeException(e); } this.finishKeyCalled = true; + + flushStateInternal(); } /** @@ -440,20 +639,23 @@ public void setActiveReader(UnboundedReader reader) { /** Invalidate the state and reader caches for this computation and key. */ public void invalidateCache() { - ByteString key = getSerializedKey(); - if (key != null) { - readerCache.invalidateReader(getComputationKey()); - if (activeReader != null) { - try { - activeReader.close(); - } catch (IOException e) { - LOG.warn( - "Failed to close reader for {}-{}", computationId, getWorkItem().getShardingKey(), e); - } + for (Work w : executedWorks) { + WindmillComputationKey compKey = + WindmillComputationKey.create(computationId, w.getShardedKey()); + readerCache.invalidateReader(compKey); + stateCache.invalidate(w.getShardedKey()); + } + if (activeReader != null) { + try { + activeReader.close(); + } catch (IOException e) { + Windmill.WorkItem workItem = getWorkItem(); + long shardingKey = workItem != null ? workItem.getShardingKey() : -1L; + LOG.warn("Failed to close reader for {}-{}", computationId, shardingKey, e); } - activeReader = null; - stateCache.invalidate(key, getWorkItem().getShardingKey()); } + activeReader = null; + activeStateReader = null; } public UnboundedSource.@Nullable CheckpointMark getReaderCheckpoint( @@ -469,8 +671,7 @@ public void invalidateCache() { } } - public Map> flushState() { - checkState(finishKeyCalled, "finishKey must be called before flushState"); + private void flushStateInternal() { Map> callbacks = new HashMap<>(); for (StepContext stepContext : getAllStepContexts()) { @@ -553,7 +754,136 @@ public Map> flushState() { // RestrictionTracker.getProgress() or GetSize() are not defined. getOutputBuilder().setSourceBacklogBytes(backlogBytes); } - return callbacks; + + this.accumulatedCallbacks.putAll(callbacks); + + getOutputBuilder() + .setSourceBytesProcessed(computeSourceBytesProcessed(sourceBytesProcessCounterName)); + } + + private final long computeSourceBytesProcessed(String sourceBytesCounterName) { + if (!(workExecutor instanceof DataflowMapTaskExecutor)) { + return 0L; + } + HashMap counters = + ((DataflowMapTaskExecutor) workExecutor) + .getReadOperation() + .receivers[0] + .getOutputCounters(); + + return Optional.ofNullable(counters.get(sourceBytesCounterName)) + .map(counter -> ((OutputObjectAndByteCounter) counter).getByteCount().getAndReset()) + .orElse(0L); + } + + public Map> flushState() { + return accumulatedCallbacks; + } + + public boolean advance() { + if (!multiKeyBundleEnabled) { + return false; + } + if (workIsFailed()) { + throw new WorkItemCancelledException(checkStateNotNull(work).getWorkItem().getShardingKey()); + } + + BoundedQueueExecutor executor = checkStateNotNull(workQueueExecutor); + BoundedQueueExecutorWorkHandle handle = checkStateNotNull(budgetHandle); + Work activeWork = checkStateNotNull(work); + + if (activeWork.getKeyGroup().equals(Work.KeyGroup.DEFAULT) || shouldStopBatching()) { + return false; + } + + @Nullable + ExecutableWork additionalWork = + executor.pollWork(computationId, activeWork.getKeyGroup(), handle); + if (additionalWork != null) { + Work newWork = additionalWork.work(); + ++workItemsPolled; + checkStateNotNull(keyTransitionListener).onKeyTransition(activeWork, newWork); + startForNewKey(newWork); + return true; + } + + return false; + } + + private boolean shouldStopBatching() { + if (workItemsPolled >= maxKeyGroupBatchSize) { + return true; + } + long elapsedNanos = System.nanoTime() - bundleStartTimeNanos; + if (elapsedNanos >= maxKeyGroupBatchTimeNanos) { + return true; + } + return getBytesSinked() >= maxKeyGroupBatchSinkBytes; + } + + private void startForNewKey(Work newWork) { + newWork.setState(Work.State.PROCESSING); + this.key = decodeKey(newWork); + this.work = newWork; + this.finishKeyCalled = false; + this.computationKey = WindmillComputationKey.create(computationId, newWork.getShardedKey()); + + this.outputBuilder = createOutputBuilder(newWork); + this.outputBuilders.add(this.outputBuilder); + newWork.setOnFailureListener(this.workBatchFailed); + this.executedWorks.add(newWork); + + logHotKeyIfDetected(newWork, this.key); + + this.sideInputStateFetcher = + sideInputStateFetcherFactory.createSideInputStateFetcher(newWork::fetchSideInput); + this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; + this.activeReader = null; + + // Note: We do NOT clear sideInputCache here, allowing Key B to reuse warm side inputs! + + // Re-initialize state cache and state/timer internals across all step contexts + Instant processingTime = + computeProcessingTime(newWork.getWorkItem().getTimers().getTimersList()); + if (!getAllStepContexts().isEmpty()) { + // This must be only created once for a workItem as token validation will fail if the same + // work token is reused. + WindmillStateCache.ForKey cacheForKey = + stateCache.forKey( + getComputationKey(), newWork.getWorkItem().getCacheToken(), getWorkToken()); + this.activeStateReader = newWork.createWindmillStateReader(this::workIsFailed); + startStepContexts(this.activeStateReader, processingTime, cacheForKey, newWork.watermarks()); + } else { + this.activeStateReader = null; + } + } + + public List getExecutedWorks() { + return executedWorks; + } + + public long getStateBytesRead() { + return stateBytesRead; + } + + public List getWorkItemCommits() { + List commits = new ArrayList<>(outputBuilders.size()); + for (Windmill.WorkItemCommitRequest.Builder builder : outputBuilders) { + commits.add(builder.build()); + } + return commits; + } + + public Map> getAccumulatedCallbacks() { + return accumulatedCallbacks; + } + + public @Nullable Object getKey() { + return key; + } + + public Work getWork() { + return checkStateNotNull(work); } @Nullable diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java index 075a1a8a4250..134655a72a54 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java @@ -35,9 +35,9 @@ public abstract class WindmillReaderIteratorBase extends NativeReader.NativeReaderIterator> { private final StreamingModeExecutionContext context; - private final Windmill.WorkItem work; - private int bundleIndex = 0; - private int messageIndex = -1; + private Windmill.WorkItem work; + private int bundleIndex; + private int messageIndex; private @Nullable WindowedValue current = null; private final ValueProvider skipUndecodableElements; private static final Logger LOG = LoggerFactory.getLogger(WindmillReaderIteratorBase.class); @@ -47,6 +47,8 @@ protected WindmillReaderIteratorBase( this.context = context; this.skipUndecodableElements = skipUndecodableElements; this.work = context.getWorkItem(); + this.bundleIndex = 0; + this.messageIndex = -1; } @Override @@ -57,15 +59,25 @@ public boolean start() throws IOException { @Override public boolean advance() throws IOException { if (context.workIsFailed()) { - throw new WorkItemCancelledException(context.getWorkItem().getShardingKey()); + throw new WorkItemCancelledException(checkNotNull(context.getWorkItem()).getShardingKey()); } while (true) { if (bundleIndex >= work.getMessageBundlesCount()) { - current = null; + // If elements are exhausted, try advancing the execution context to the next key in the + // group context.finishKey(); + if (context.advance()) { + // Transition succeeded! Update iterator references to the new work item + resetWorkFromContext(); + continue; + } + + // All work items are exhausted. + current = null; return false; } + Windmill.InputMessageBundle bundle = work.getMessageBundles(bundleIndex); ++messageIndex; if (messageIndex >= bundle.getMessagesCount()) { @@ -73,6 +85,7 @@ public boolean advance() throws IOException { ++bundleIndex; continue; } + try { current = checkNotNull(decodeMessage(bundle.getMessages(messageIndex))); return true; @@ -91,6 +104,12 @@ public boolean advance() throws IOException { } } + private void resetWorkFromContext() { + this.work = context.getWorkItem(); + this.bundleIndex = 0; + this.messageIndex = -1; + } + protected abstract WindowedValue decodeMessage(Windmill.Message message) throws IOException; @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java index 488684769bd9..2003ec001a55 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java @@ -30,7 +30,6 @@ import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.worker.util.ValueInEmptyWindows; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -49,7 +48,6 @@ @Internal class WindowingWindmillReader extends NativeReader>> { - private final Coder keyCoder; private final Coder valueCoder; private final Coder windowCoder; private final Coder> windowsCoder; @@ -66,7 +64,6 @@ class WindowingWindmillReader extends NativeReader keyedWorkItemCoder = (WindmillKeyedWorkItem.FakeKeyedWorkItemCoder) inputCoder.getValueCoder(); - this.keyCoder = keyedWorkItemCoder.getKeyCoder(); this.valueCoder = keyedWorkItemCoder.getElementCoder(); this.context = context; this.skipUndecodableElements = skipUndecodableElements; @@ -129,73 +126,86 @@ public static WindowingWindmillReader create( return new WindowingWindmillReader<>(coder, context, skipUndecodableElements); } + private KeyedWorkItem createKeyedWorkItem() { + @SuppressWarnings("unchecked") + @Nullable + K key = (K) context.getKey(); + return new WindmillKeyedWorkItem<>( + key, + context.getWorkItem(), + windowCoder, + windowsCoder, + valueCoder, + context.getWindmillTagEncoding(), + context.getDrainMode(), + skipUndecodableElements.isAccessible() + && Boolean.TRUE.equals(skipUndecodableElements.get())); + } + + private boolean isEmpty(KeyedWorkItem keyedWorkItem) { + return Iterables.isEmpty(keyedWorkItem.timersIterable()) + && Iterables.isEmpty(keyedWorkItem.elementsIterable()); + } + @Override public NativeReaderIterator>> iterator() throws IOException { - final K key = - keyCoder.decode( - checkStateNotNull(context.getSerializedKey()).newInput(), Coder.Context.OUTER); - final WorkItem workItem = context.getWorkItem(); - KeyedWorkItem keyedWorkItem = - new WindmillKeyedWorkItem<>( - key, - workItem, - windowCoder, - windowsCoder, - valueCoder, - context.getWindmillTagEncoding(), - context.getDrainMode(), - skipUndecodableElements.isAccessible() - && Boolean.TRUE.equals(skipUndecodableElements.get())); - final boolean isEmptyWorkItem = - (Iterables.isEmpty(keyedWorkItem.timersIterable()) - && Iterables.isEmpty(keyedWorkItem.elementsIterable())); - final WindowedValue> value = new ValueInEmptyWindows<>(keyedWorkItem); - - // Return a noop iterator when current workitem is an empty workitem. - if (isEmptyWorkItem) { - return new NativeReaderIterator>>() { - @Override - public boolean start() throws IOException { - context.finishKey(); - return false; + final KeyedWorkItem firstKeyedWorkItem = createKeyedWorkItem(); + final boolean firstKeyIsEmpty = isEmpty(firstKeyedWorkItem); + final WindowedValue> firstValue = + new ValueInEmptyWindows<>(firstKeyedWorkItem); + + return new NativeReaderIterator>>() { + private @Nullable WindowedValue> current = null; + private boolean started = false; + + @Override + public boolean start() throws IOException { + if (context.workIsFailed()) { + throw new WorkItemCancelledException( + checkStateNotNull(context.getWorkItem()).getShardingKey()); } - - @Override - public boolean advance() throws IOException { + if (started) { return false; } - - @Override - public WindowedValue> getCurrent() { - throw new NoSuchElementException(); + started = true; + if (firstKeyIsEmpty) { + return advance(); // Try to transition immediately if the first key is empty! } - }; - } else { - return new NativeReaderIterator>>() { - private @Nullable WindowedValue> current = null; - - @Override - public boolean start() throws IOException { - current = value; - return true; + current = firstValue; + return true; + } + + @Override + public boolean advance() throws IOException { + if (context.workIsFailed()) { + throw new WorkItemCancelledException( + checkStateNotNull(context.getWorkItem()).getShardingKey()); } - @Override - public boolean advance() throws IOException { - current = null; + while (true) { context.finishKey(); + if (context.advance()) { + KeyedWorkItem newKeyedWorkItem = createKeyedWorkItem(); + if (isEmpty(newKeyedWorkItem)) { + continue; + } + current = new ValueInEmptyWindows<>(newKeyedWorkItem); + return true; + } + + current = null; return false; } + } - @Override - public WindowedValue> getCurrent() { - if (current == null) { - throw new NoSuchElementException(); - } - return value; + @Override + public WindowedValue> getCurrent() { + if (current == null) { + throw new NoSuchElementException(); } - }; - } + return current; + } + }; } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkCancelingException.java similarity index 54% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidException.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkCancelingException.java index 29b16b71883f..73a307641b96 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkCancelingException.java @@ -17,21 +17,30 @@ */ package org.apache.beam.runners.dataflow.worker; -import javax.annotation.Nullable; +import org.checkerframework.checker.nullness.qual.Nullable; -/** Indicates that the key token was invalid when data was attempted to be fetched. */ -public class KeyTokenInvalidException extends RuntimeException { - public KeyTokenInvalidException(String key) { - super("Unable to fetch data due to token mismatch for key " + key); +/** + * Indicates that the work is no longer valid and should be canceled. It is thrown as a signal for + * upper layers to mark the work as failed. + */ +public class WorkCancelingException extends RuntimeException { + + public WorkCancelingException(long sharding_key) { + super("Work canceling exception for key " + sharding_key); + } + + public WorkCancelingException(Throwable cause) { + super(cause); } - /** Returns whether an exception was caused by a {@link KeyTokenInvalidException}. */ - public static boolean isKeyTokenInvalidException(@Nullable Throwable t) { - while (t != null) { - if (t instanceof KeyTokenInvalidException) { + /** Returns whether an exception was caused by a {@link WorkCancelingException}. */ + public static boolean isWorkCancelingException(Throwable t) { + @Nullable Throwable throwable = t; + while (throwable != null) { + if (throwable instanceof WorkCancelingException) { return true; } - t = t.getCause(); + throwable = throwable.getCause(); } return false; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java index a12a5075c5ee..68cbab32254c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java @@ -17,31 +17,10 @@ */ package org.apache.beam.runners.dataflow.worker; -/** Indicates that the work item was cancelled and should not be retried. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) +/** Indicates that the work item was canceled. */ public class WorkItemCancelledException extends RuntimeException { + public WorkItemCancelledException(long sharding_key) { super("Work item cancelled for key " + sharding_key); } - - public WorkItemCancelledException(String message, Throwable cause) { - super(message, cause); - } - - public WorkItemCancelledException(Throwable cause) { - super(cause); - } - - /** Returns whether an exception was caused by a {@link WorkItemCancelledException}. */ - public static boolean isWorkItemCancelledException(Throwable t) { - while (t != null) { - if (t instanceof WorkItemCancelledException) { - return true; - } - t = t.getCause(); - } - return false; - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java index e430f6c8f638..f49aa31a439a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java @@ -88,6 +88,11 @@ static ActiveWorkState create(WindmillStateCache.ForComputation computationState return new ActiveWorkState(new HashMap<>(), computationStateCache); } + synchronized Optional getActiveWork(ShardedKey shardedKey, WorkId workId) { + LinkedHashMap workQueue = activeWork.get(shardedKey.shardingKey()); + return workQueue == null ? Optional.empty() : Optional.ofNullable(workQueue.get(workId)); + } + @VisibleForTesting static ActiveWorkState forTesting( Map> activeWork, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/BoundedQueueExecutorWorkHandle.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/BoundedQueueExecutorWorkHandle.java index 1ca534966947..20661aae0a04 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/BoundedQueueExecutorWorkHandle.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/BoundedQueueExecutorWorkHandle.java @@ -17,8 +17,13 @@ */ package org.apache.beam.runners.dataflow.worker.streaming; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; + /** * A handle to use when requesting pulling more work from @BoundedQueueExecutor * via @BoundedQueueExecutor.pollWork */ -public interface BoundedQueueExecutorWorkHandle {} +public interface BoundedQueueExecutorWorkHandle { + // Returns all work that are tracked by the handle + ImmutableList getWorkBatch(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java index 3886d4fbc01b..e9f6ddc55de6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java @@ -131,6 +131,10 @@ public void completeWorkAndScheduleNextWorkForKey(ShardedKey shardedKey, WorkId .ifPresent(this::forceExecute); } + public void reExecuteActiveWork(ShardedKey shardedKey, WorkId workId) { + activeWorkState.getActiveWork(shardedKey, workId).ifPresent(this::forceExecute); + } + public void invalidateStuckCommits(Instant stuckCommitDeadline) { activeWorkState.invalidateStuckCommits( stuckCommitDeadline, this::completeWorkAndScheduleNextWorkForKey); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java index b4f3a22a7f52..5208ee475f47 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -18,21 +18,15 @@ package org.apache.beam.runners.dataflow.worker.streaming; import com.google.auto.value.AutoValue; -import java.util.HashMap; import java.util.Optional; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.runners.core.metrics.ExecutionStateTracker; -import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutor; import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; -import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; -import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; -import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,13 +62,19 @@ public static ComputationWorkExecutor.Builder builder() { * Executes DoFns for the Work. Blocks the calling thread until DoFn(s) have completed execution. */ public final void executeWork( - @Nullable Object key, Work work, - WindmillStateReader stateReader, - SideInputStateFetcher sideInputStateFetcher, - Windmill.WorkItemCommitRequest.Builder outputBuilder) + BoundedQueueExecutor workQueueExecutor, + BoundedQueueExecutorWorkHandle budgetHandle, + KeyTransitionListener keyTransitionListener) throws Exception { - context().start(key, work, stateReader, sideInputStateFetcher, outputBuilder, workExecutor()); + context() + .start( + work, + workExecutor(), + workQueueExecutor, + budgetHandle, + keyCoder().orElse(null), + keyTransitionListener); workExecutor().execute(); } @@ -84,6 +84,7 @@ public final void executeWork( */ public final void invalidate() { context().invalidateCache(); + context().clear(); try { workExecutor().close(); } catch (Exception e) { @@ -91,18 +92,6 @@ public final void invalidate() { } } - public final long computeSourceBytesProcessed(String sourceBytesCounterName) { - HashMap counters = - ((DataflowMapTaskExecutor) workExecutor()) - .getReadOperation() - .receivers[0] - .getOutputCounters(); - - return Optional.ofNullable(counters.get(sourceBytesCounterName)) - .map(counter -> ((OutputObjectAndByteCounter) counter).getByteCount().getAndReset()) - .orElse(0L); - } - @AutoValue.Builder public abstract static class Builder { public abstract Builder setWorkExecutor(DataflowWorkExecutor workExecutor); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index 53ed30fdedbb..f9cfec7e6807 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -28,12 +28,15 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Supplier; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.ActiveMessageMetadata; import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; +import org.apache.beam.runners.dataflow.worker.WorkCancelingException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; @@ -82,7 +85,10 @@ public final class Work implements RefreshableWork { private volatile TimedState currentState; private volatile boolean isFailed; private volatile String processingThreadName = ""; + private final AtomicReference<@Nullable AtomicBoolean> onFailureListener = + new AtomicReference<>(null); private final boolean drainMode; + private ImmutableList getWorkStreamLatencies; private Work( WorkItem workItem, @@ -90,7 +96,8 @@ private Work( Watermarks watermarks, ProcessingContext processingContext, boolean drainMode, - Supplier clock) { + Supplier clock, + ImmutableList getWorkStreamLatencies) { this.shardedKey = ShardedKey.create(workItem.getKey(), workItem.getShardingKey()); this.workItem = workItem; this.serializedWorkItemSize = serializedWorkItemSize; @@ -114,6 +121,7 @@ private Work( + Long.toHexString(workItem.getWorkToken()); this.currentState = TimedState.initialState(startTime); this.isFailed = false; + this.getWorkStreamLatencies = getWorkStreamLatencies; } public static Work create( @@ -124,7 +132,31 @@ public static Work create( boolean drainMode, Supplier clock) { return new Work( - workItem, serializedWorkItemSize, watermarks, processingContext, drainMode, clock); + workItem, + serializedWorkItemSize, + watermarks, + processingContext, + drainMode, + clock, + ImmutableList.of()); + } + + public static Work create( + WorkItem workItem, + long serializedWorkItemSize, + Watermarks watermarks, + ProcessingContext processingContext, + boolean drainMode, + Supplier clock, + ImmutableList getWorkStreamLatencies) { + return new Work( + workItem, + serializedWorkItemSize, + watermarks, + processingContext, + drainMode, + clock, + getWorkStreamLatencies); } public static ProcessingContext createProcessingContext( @@ -191,17 +223,41 @@ public long getSerializedWorkItemSize() { return serializedWorkItemSize; } + public String getComputationId() { + return processingContext.computationId(); + } + @Override public ShardedKey getShardedKey() { return shardedKey; } public Optional fetchKeyedState(KeyedGetDataRequest keyedGetDataRequest) { - return processingContext.fetchKeyedState(keyedGetDataRequest); + try { + Optional response = + processingContext.fetchKeyedState(keyedGetDataRequest); + if (response.isPresent() && response.get().getFailed()) { + // Work is not valid in backend anymore. + this.setFailed(); + } + return response; + } catch (RuntimeException e) { + if (WorkCancelingException.isWorkCancelingException(e)) { + this.setFailed(); + } + throw e; + } } public GlobalData fetchSideInput(GlobalDataRequest request) { - return processingContext.getDataClient().getSideInputData(request); + try { + return processingContext.getDataClient().getSideInputData(request); + } catch (RuntimeException e) { + if (WorkCancelingException.isWorkCancelingException(e)) { + this.setFailed(); + } + throw e; + } } public String backendWorkerToken() { @@ -244,6 +300,19 @@ public void setProcessingThreadName(String processingThreadName) { @Override public void setFailed() { this.isFailed = true; + AtomicBoolean listener = onFailureListener.get(); + if (listener != null) { + listener.set(true); + } + } + + // Sets the passed in boolean to true if the work fails + // Supports registering only one boolean at a time. + public void setOnFailureListener(@Nullable AtomicBoolean listener) { + onFailureListener.set(listener); + if (isFailed && listener != null) { + listener.set(true); + } } public boolean isCommitPending() { @@ -268,8 +337,12 @@ public void queueCommit(WorkItemCommitRequest commitRequest, ComputationState co processingContext.workCommitter().accept(Commit.create(commitRequest, computationState, this)); } - public WindmillStateReader createWindmillStateReader() { - return WindmillStateReader.forWork(this); + public Consumer workCommitter() { + return processingContext.workCommitter(); + } + + public WindmillStateReader createWindmillStateReader(Supplier workIsFailed) { + return WindmillStateReader.forWork(this, workIsFailed); } @Override @@ -277,11 +350,17 @@ public WorkId id() { return id; } - public void recordGetWorkStreamLatencies( - ImmutableList getWorkStreamLatencies) { - for (LatencyAttribution latency : getWorkStreamLatencies) { - totalDurationPerState.put( - latency.getState(), Duration.millis(latency.getTotalDurationMillis())); + public ImmutableList getWorkStreamLatencies() { + return getWorkStreamLatencies; + } + + public void recordGetWorkStreamLatencies() { + if (!getWorkStreamLatencies.isEmpty()) { + for (LatencyAttribution latency : getWorkStreamLatencies) { + totalDurationPerState.put( + latency.getState(), Duration.millis(latency.getTotalDurationMillis())); + } + this.getWorkStreamLatencies = ImmutableList.of(); } } @@ -390,10 +469,6 @@ private boolean isCommitPending() { abstract Instant startTime(); } - public String getComputationId() { - return processingContext.computationId(); - } - public KeyGroup getKeyGroup() { return keyGroup; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java index 8964246c1160..9eb9a37b1b76 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java @@ -20,6 +20,8 @@ import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadFactory; @@ -30,9 +32,9 @@ import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Work; -import org.apache.beam.runners.dataflow.worker.streaming.Work.KeyGroup; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor.Guard; import org.checkerframework.checker.nullness.qual.Nullable; @@ -260,7 +262,7 @@ final class BoundedQueueExecutorWorkHandleImpl implements BoundedQueueExecutorWorkHandle, AutoCloseable { @GuardedBy("this") - private int elements; + private final List workBatch; @GuardedBy("this") private long bytes; @@ -268,16 +270,17 @@ final class BoundedQueueExecutorWorkHandleImpl @GuardedBy("this") private boolean closed = false; - private BoundedQueueExecutorWorkHandleImpl(int elements, long bytes) { - checkArgument(elements >= 0 && bytes >= 0); - this.elements = elements; + private BoundedQueueExecutorWorkHandleImpl(Work work, long bytes) { + checkArgument(bytes >= 0); + this.workBatch = new ArrayList<>(); + this.workBatch.add(checkArgumentNotNull(work)); this.bytes = bytes; } /** * Merges the budget from another handle into this handle. * - *

This transfers the budget (elements and bytes) from the {@code other} handle to this + *

This transfers the budget (workBatch and bytes) from the {@code other} handle to this * handle, and marks the {@code other} handle as closed to prevent it from releasing the budget * again if it is closed. */ @@ -287,10 +290,10 @@ public void merge(BoundedQueueExecutorWorkHandleImpl other) { Preconditions.checkState(!closed, "Cannot merge into a closed handle"); synchronized (other) { Preconditions.checkState(!other.closed, "Cannot merge a closed handle"); - this.elements += other.elements; + this.workBatch.addAll(other.workBatch); this.bytes += other.bytes; other.closed = true; - other.elements = 0; + other.workBatch.clear(); other.bytes = 0; } } @@ -300,9 +303,9 @@ public synchronized boolean isClosed() { return closed; } - @VisibleForTesting - synchronized int elements() { - return elements; + @Override + public synchronized ImmutableList getWorkBatch() { + return ImmutableList.copyOf(workBatch); } @VisibleForTesting @@ -314,7 +317,7 @@ synchronized long bytes() { public synchronized void close() { if (closed) return; closed = true; - decrementCounters(this.elements, this.bytes); + decrementCounters(this.workBatch.size(), this.bytes); } } @@ -350,7 +353,7 @@ private void executeMonitorHeld(ExecutableWork work, long workBytes) { bytesOutstanding += workBytes; monitor.leave(); BoundedQueueExecutorWorkHandleImpl handle = - new BoundedQueueExecutorWorkHandleImpl(1, workBytes); + new BoundedQueueExecutorWorkHandleImpl(work.work(), workBytes); try { executor.execute(new QueuedWork(work, handle)); } catch (Throwable t) { @@ -379,14 +382,15 @@ private void executeMonitorHeld(Runnable work) { } @VisibleForTesting - BoundedQueueExecutorWorkHandleImpl createBudgetHandle(int elements, long bytes) { - return new BoundedQueueExecutorWorkHandleImpl(elements, bytes); + BoundedQueueExecutorWorkHandleImpl createBudgetHandle(Work work, long bytes) { + return new BoundedQueueExecutorWorkHandleImpl(work, bytes); } public @Nullable ExecutableWork pollWork( String computationId, Work.KeyGroup keyGroup, BoundedQueueExecutorWorkHandle handle) { + checkArgument( + computationId != null && keyGroup != null && !keyGroup.equals(Work.KeyGroup.DEFAULT)); checkArgument(handle instanceof BoundedQueueExecutorWorkHandleImpl); - checkArgument(computationId != null && keyGroup != null && !keyGroup.equals(KeyGroup.DEFAULT)); BoundedQueueExecutorWorkHandleImpl internalHandle = (BoundedQueueExecutorWorkHandleImpl) handle; if (keyGroupWorkQueue == null) { return null; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index 526b67890783..36001c151508 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -108,6 +108,11 @@ boolean commitWorkItem( Windmill.WorkItemCommitRequest request, Consumer onDone); + boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone); + /** Flushes any pending work items to the wire. */ void flush(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java index b840d22a3434..e52a9846645f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java @@ -18,11 +18,14 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.commits; import com.google.auto.value.AutoValue; +import java.util.Optional; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; /** Value class for a queued commit. */ @Internal @@ -32,20 +35,43 @@ public abstract class Commit { public static Commit create( WorkItemCommitRequest request, ComputationState computationState, Work work) { Preconditions.checkArgument(request.getSerializedSize() > 0); - return new AutoValue_Commit(request, computationState, work); + return new AutoValue_Commit( + Optional.of(request), computationState, Optional.empty(), ImmutableList.of(work)); + } + + public static Commit createMultiKey( + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest, + ComputationState computationState, + ImmutableList workBatch) { + Preconditions.checkArgument(!workBatch.isEmpty()); + return new AutoValue_Commit( + Optional.empty(), computationState, Optional.of(multiKeyRequest), workBatch); } public final String computationId() { return computationState().getComputationId(); } - public abstract WorkItemCommitRequest request(); + public abstract Optional singleKeyRequest(); public abstract ComputationState computationState(); - public abstract Work work(); + public abstract Optional multiKeyRequest(); + + public abstract ImmutableList workBatch(); + + public final boolean isFailed() { + for (Work w : workBatch()) { + if (w.isFailed()) { + return true; + } + } + return false; + } public final int getSize() { - return request().getSerializedSize(); + return multiKeyRequest() + .map(Windmill.MultiKeyWorkItemCommitRequest::getSerializedSize) + .orElseGet(() -> singleKeyRequest().get().getSerializedSize()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java index e33e853d3d76..e168d92987fb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java @@ -37,24 +37,14 @@ @AutoValue public abstract class CompleteCommit { - public static CompleteCommit create(Commit commit, CommitStatus commitStatus) { - return new AutoValue_CompleteCommit( - commit.computationId(), - ShardedKey.create(commit.request().getKey(), commit.request().getShardingKey()), - WorkId.builder() - .setWorkToken(commit.request().getWorkToken()) - .setCacheToken(commit.request().getCacheToken()) - .build(), - commitStatus); - } - public static CompleteCommit create( - String computationId, ShardedKey shardedKey, WorkId workId, CommitStatus status) { - return new AutoValue_CompleteCommit(computationId, shardedKey, workId, status); - } - - public static CompleteCommit forFailedWork(Commit commit) { - return create(commit, CommitStatus.ABORTED); + String computationId, + ShardedKey shardedKey, + WorkId workId, + CommitStatus status, + boolean retryableFailure) { + return new AutoValue_CompleteCommit( + computationId, shardedKey, workId, status, retryableFailure); } public abstract String computationId(); @@ -64,4 +54,6 @@ public static CompleteCommit forFailedWork(Commit commit) { public abstract WorkId workId(); public abstract CommitStatus status(); + + public abstract boolean retryableFailure(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java index 20b95b0661d0..58f0dbbea242 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.commits; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutorService; @@ -112,7 +114,8 @@ private void commitLoop() { } while (commit != null) { ComputationState computationState = commit.computationState(); - commit.work().setState(Work.State.COMMITTING); + checkState(commit.workBatch().size() == 1); + commit.workBatch().get(0).setState(Work.State.COMMITTING); Windmill.ComputationCommitWorkRequest.Builder computationRequestBuilder = computationRequestMap.get(computationState); if (computationRequestBuilder == null) { @@ -120,7 +123,8 @@ private void commitLoop() { computationRequestBuilder.setComputationId(computationState.getComputationId()); computationRequestMap.put(computationState, computationRequestBuilder); } - computationRequestBuilder.addRequests(commit.request()); + checkState(commit.singleKeyRequest().isPresent()); + computationRequestBuilder.addRequests(commit.singleKeyRequest().get()); // Send the request if we've exceeded the bytes or there is no more // pending work. commitBytes is a long, so this cannot overflow. commitBytes += commit.getSize(); @@ -155,7 +159,8 @@ private void completeWork( .setCacheToken(workRequest.getCacheToken()) .setWorkToken(workRequest.getWorkToken()) .build(), - Windmill.CommitStatus.OK)); + Windmill.CommitStatus.OK, + /* retryableFailure= */ false)); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index b68f53121b86..72d9e5ed8d03 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -30,6 +30,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.sdk.annotations.Internal; @@ -100,7 +101,7 @@ public void start() { @Override public void commit(Commit commit) { - if (commit.work().isFailed()) { + if (commit.isFailed()) { failCommit(commit); } else { commitQueue.put(commit); @@ -113,8 +114,8 @@ public void commit(Commit commit) { "Trying to queue commit on shutdown, failing commit=[computationId={}, shardingKey={}," + " workId={} ].", commit.computationId(), - commit.work().getShardedKey(), - commit.work().id()); + commit.workBatch().get(0).getShardedKey(), + commit.workBatch().get(0).id()); drainCommitQueue(); } } @@ -147,8 +148,42 @@ private void drainCommitQueue() { } private void failCommit(Commit commit) { - commit.work().setFailed(); - onCommitComplete.accept(CompleteCommit.forFailedWork(commit)); + if (!isRunning.get()) { + // Shutting down, fail everything unconditionally to prevent infinite loops + for (Work w : commit.workBatch()) { + w.setFailed(); + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ false)); + } + return; + } + + // Still running, only fail actually failed work, and request re-execution for valid ones + for (Work w : commit.workBatch()) { + if (w.isFailed()) { + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ false)); + } else { + LOG.debug("Requesting re-execution for valid work {} from failed commit", w.id()); + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ true)); + } + } } @Override @@ -173,8 +208,8 @@ private void streamingCommitLoop() { // take() blocks until a value is available in the commitQueue. Preconditions.checkNotNull(initialCommit); - if (initialCommit.work().isFailed()) { - onCommitComplete.accept(CompleteCommit.forFailedWork(initialCommit)); + if (initialCommit.isFailed()) { + failCommit(initialCommit); initialCommit = null; continue; } @@ -202,20 +237,51 @@ private void streamingCommitLoop() { /** Adds the commit to the batch if it fits, returning true if it is consumed. */ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatcher batcher) { Preconditions.checkNotNull(commit); - commit.work().setState(Work.State.COMMITTING); + for (Work w : commit.workBatch()) { + w.setState(Work.State.COMMITTING); + } activeCommitBytes.addAndGet(commit.getSize()); - boolean isCommitAccepted = - batcher.commitWorkItem( - commit.computationId(), - commit.request(), - commitStatus -> { - onCommitComplete.accept(CompleteCommit.create(commit, commitStatus)); - activeCommitBytes.addAndGet(-commit.getSize()); - }); + boolean isCommitAccepted; + if (commit.multiKeyRequest().isPresent()) { + isCommitAccepted = + batcher.commitMultiKeyWorkItem( + commit.computationId(), + commit.multiKeyRequest().get(), + commitStatus -> { + for (Work w : commit.workBatch()) { + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + commitStatus, + /* retryableFailure= */ false)); + } + activeCommitBytes.addAndGet(-commit.getSize()); + }); + } else { + isCommitAccepted = + batcher.commitWorkItem( + commit.computationId(), + commit.singleKeyRequest().get(), + commitStatus -> { + Work w = commit.workBatch().get(0); + onCommitComplete.accept( + CompleteCommit.create( + commit.computationId(), + w.getShardedKey(), + w.id(), + commitStatus, + /* retryableFailure= */ false)); + activeCommitBytes.addAndGet(-commit.getSize()); + }); + } // Since the commit was not accepted, revert the changes made above. if (!isCommitAccepted) { - commit.work().setState(Work.State.COMMIT_QUEUED); + for (Work w : commit.workBatch()) { + w.setState(Work.State.COMMIT_QUEUED); + } activeCommitBytes.addAndGet(-commit.getSize()); } @@ -246,8 +312,8 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch } // Drop commits for failed work. Such commits will be dropped by Windmill anyway. - if (commit.work().isFailed()) { - onCommitComplete.accept(CompleteCommit.forFailedWork(commit)); + if (commit.isFailed()) { + failCommit(commit); continue; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java index ab12946ad18b..d233bf091b6a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java @@ -19,6 +19,7 @@ import java.io.PrintWriter; import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.WorkCancelingException; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; @@ -62,7 +63,7 @@ public Windmill.KeyedGetDataResponse getStateData( try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { return getDataStream.requestKeyedData(computationId, request); } catch (WindmillStreamShutdownException e) { - throw new WorkItemCancelledException(request.getShardingKey()); + throw new WorkCancelingException(request.getShardingKey()); } catch (Exception e) { throw new GetDataException( "Error occurred fetching state for computation=" @@ -87,7 +88,7 @@ public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) { return sideInputGetDataStream.requestGlobalData(request); } catch (WindmillStreamShutdownException e) { - throw new WorkItemCancelledException(e); + throw new WorkCancelingException(e); } catch (Exception e) { throw new GetDataException( "Error occurred fetching side input for tag=" + request.getDataId(), e); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index d24676652fd8..afa736d7c3ad 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -35,6 +35,7 @@ import java.util.function.Function; import javax.annotation.Nullable; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk; @@ -270,7 +271,7 @@ private void flushInternal(Map requests) if (requests.size() == 1) { Map.Entry elem = requests.entrySet().iterator().next(); - if (elem.getValue().request().getSerializedSize() + if (elem.getValue().serializedCommit().size() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { issueMultiChunkRequest(elem.getKey(), elem.getValue()); } else { @@ -289,6 +290,7 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) .setComputationId(pendingRequest.computationId()) .setRequestId(id) .setShardingKey(pendingRequest.shardingKey()) + .setCommitType(pendingRequest.commitType()) .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = requestBuilder.build(); synchronized (this) { @@ -318,7 +320,8 @@ private void issueBatchedRequest(Map requests) chunkBuilder .setRequestId(entry.getKey()) .setShardingKey(request.shardingKey()) - .setSerializedWorkItemCommit(request.serializedCommit()); + .setSerializedWorkItemCommit(request.serializedCommit()) + .setCommitType(request.commitType()); } StreamingCommitWorkRequest request = requestBuilder.build(); synchronized (this) { @@ -360,7 +363,8 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) .setRequestId(id) .setSerializedWorkItemCommit(chunk) .setComputationId(pendingRequest.computationId()) - .setShardingKey(pendingRequest.shardingKey()); + .setShardingKey(pendingRequest.shardingKey()) + .setCommitType(pendingRequest.commitType()); int remaining = serializedCommit.size() - end; if (remaining > 0) { chunkBuilder.setRemainingBytesForWorkItem(remaining); @@ -378,34 +382,34 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) @AutoValue abstract static class PendingRequest { - - private static PendingRequest create( - String computationId, WorkItemCommitRequest request, Consumer onDone) { - return new AutoValue_GrpcCommitWorkStream_PendingRequest(computationId, request, onDone); + static PendingRequest create( + String computationId, + long shardingKey, + ByteString serializedCommit, + StreamingCommitRequestChunk.CommitType commitType, + Consumer onDone) { + return new AutoValue_GrpcCommitWorkStream_PendingRequest( + computationId, shardingKey, serializedCommit, commitType, onDone); } abstract String computationId(); - abstract WorkItemCommitRequest request(); + abstract long shardingKey(); + + abstract ByteString serializedCommit(); + + abstract StreamingCommitRequestChunk.CommitType commitType(); abstract Consumer onDone(); private long getBytes() { - return (long) request().getSerializedSize() + computationId().length(); - } - - private ByteString serializedCommit() { - return request().toByteString(); + return (long) serializedCommit().size() + computationId().length(); } private void completeWithStatus(CommitStatus commitStatus) { onDone().accept(commitStatus); } - private long shardingKey() { - return request().getShardingKey(); - } - private void abort() { completeWithStatus(CommitStatus.ABORTED); } @@ -462,7 +466,34 @@ public boolean commitWorkItem( return false; } - PendingRequest request = PendingRequest.create(computation, commitRequest, onDone); + PendingRequest request = + PendingRequest.create( + computation, + commitRequest.getShardingKey(), + commitRequest.toByteString(), + StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_SINGLE_KEY, + onDone); + add(idGenerator.incrementAndGet(), request); + return true; + } + + @Override + public boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest commitRequest, + Consumer onDone) { + if (!canAccept(commitRequest.getSerializedSize() + computation.length())) { + return false; + } + Preconditions.checkArgument(commitRequest.getRequestsCount() > 0); + PendingRequest request = + PendingRequest.create( + computation, + // Any key in the batch for routing + commitRequest.getRequests(0).getShardingKey(), + commitRequest.toByteString(), + StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY, + onDone); add(idGenerator.incrementAndGet(), request); return true; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java index c609bed4eae0..6c5ae50858cc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java @@ -36,8 +36,8 @@ import java.util.function.Supplier; import java.util.stream.Collectors; import javax.annotation.Nullable; -import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; +import org.apache.beam.runners.dataflow.worker.WorkCancelingException; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -153,7 +153,7 @@ static WindmillStateReader forTesting( fetchStateFromWindmillFn, key, shardingKey, workToken, () -> null, () -> Boolean.FALSE); } - public static WindmillStateReader forWork(Work work) { + public static WindmillStateReader forWork(Work work, Supplier workItemIsFailed) { return new WindmillStateReader( work::fetchKeyedState, work.getWorkItem().getKey(), @@ -163,7 +163,7 @@ public static WindmillStateReader forWork(Work work) { work.setState(Work.State.READING); return () -> work.setState(Work.State.PROCESSING); }, - work::isFailed); + workItemIsFailed); } private Future stateFuture(StateTag stateTag, @Nullable Coder coder) { @@ -588,7 +588,8 @@ private KeyedGetDataRequest createRequest(Iterable> toFetch) { private void consumeResponse(KeyedGetDataResponse response, Set> toFetch) { bytesRead += response.getSerializedSize(); if (response.getFailed()) { - throw new KeyTokenInvalidException(key.toStringUtf8()); + // upper layers will fail the work on seeing this exception. + throw new WorkCancelingException(shardingKey); } if (!key.equals(response.getKey())) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java index 269799903300..86449b1c2bb1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java @@ -29,6 +29,7 @@ import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutor; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutorFactory; +import org.apache.beam.runners.dataflow.worker.HotKeyLogger; import org.apache.beam.runners.dataflow.worker.IntrinsicMapTaskExecutorFactory; import org.apache.beam.runners.dataflow.worker.ReaderCache; import org.apache.beam.runners.dataflow.worker.ReaderRegistry; @@ -48,6 +49,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.ComputationWorkExecutor; import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.common.worker.MapTaskExecutor; import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; import org.apache.beam.runners.dataflow.worker.util.common.worker.ReadOperation; @@ -82,6 +84,7 @@ final class ComputationWorkExecutorFactory { private final SinkRegistry sinkRegistry; private final DataflowExecutionStateSampler sampler; private final CounterSet pendingDeltaCounters; + private final SideInputStateFetcherFactory sideInputStateFetcherFactory; /** * Function which converts map tasks to their network representation for execution. @@ -97,6 +100,7 @@ final class ComputationWorkExecutorFactory { private final IdGenerator idGenerator; private final StreamingGlobalConfigHandle globalConfigHandle; private final boolean throwExceptionOnLargeOutput; + private final HotKeyLogger hotKeyLogger; ComputationWorkExecutorFactory( DataflowWorkerHarnessOptions options, @@ -106,7 +110,9 @@ final class ComputationWorkExecutorFactory { DataflowExecutionStateSampler sampler, CounterSet pendingDeltaCounters, IdGenerator idGenerator, - StreamingGlobalConfigHandle globalConfigHandle) { + StreamingGlobalConfigHandle globalConfigHandle, + HotKeyLogger hotKeyLogger, + SideInputStateFetcherFactory sideInputStateFetcherFactory) { this.options = options; this.mapTaskExecutorFactory = mapTaskExecutorFactory; this.readerCache = readerCache; @@ -124,6 +130,8 @@ final class ComputationWorkExecutorFactory { : StreamingDataflowWorker.MAX_SINK_BYTES; this.throwExceptionOnLargeOutput = hasExperiment(options, THROW_EXCEPTIONS_ON_LARGE_OUTPUT_EXPERIMENT); + this.hotKeyLogger = hotKeyLogger; + this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; } private static Nodes.ParallelInstructionNode extractReadNode( @@ -191,8 +199,12 @@ ComputationWorkExecutor createComputationWorkExecutor( DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker = createExecutionStateTracker(stageInfo, mapTask, workLatencyTrackingId); + boolean hotKeyLoggingEnabled = + options.isHotKeyLoggingEnabled() || hasExperiment(options, "enable_hot_key_logging"); + String stepName = getShuffleTaskStepName(mapTask); StreamingModeExecutionContext context = - createExecutionContext(computationState, stageInfo, executionStateTracker); + createExecutionContext( + computationState, stageInfo, executionStateTracker, hotKeyLoggingEnabled, stepName); DataflowMapTaskExecutor mapTaskExecutor = createMapTaskExecutor(context, mapTask, mapTaskNetwork); ReadOperation readOperation = getValidatedReadOperation(mapTaskExecutor); @@ -255,7 +267,9 @@ ComputationWorkExecutor createComputationWorkExecutor( private StreamingModeExecutionContext createExecutionContext( ComputationState computationState, StageInfo stageInfo, - DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker) { + DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker, + boolean hotKeyLoggingEnabled, + String stepName) { String computationId = computationState.getComputationId(); return new StreamingModeExecutionContext( pendingDeltaCounters, @@ -268,7 +282,13 @@ private StreamingModeExecutionContext createExecutionContext( stageInfo.executionStateRegistry(), globalConfigHandle, maxSinkBytes, - throwExceptionOnLargeOutput); + throwExceptionOnLargeOutput, + hotKeyLogger, + hotKeyLoggingEnabled, + stepName, + computationState.sourceBytesProcessCounterName(), + options, + sideInputStateFetcherFactory); } private DataflowMapTaskExecutor createMapTaskExecutor( @@ -286,6 +306,12 @@ private DataflowMapTaskExecutor createMapTaskExecutor( idGenerator); } + private static String getShuffleTaskStepName(MapTask mapTask) { + // The MapTask instruction is ordered by dependencies, such that the first element is + // always going to be the shuffle task. + return mapTask.getInstructions().get(0).getName(); + } + private DataflowExecutionContext.DataflowExecutionStateTracker createExecutionStateTracker( StageInfo stageInfo, MapTask mapTask, String workLatencyTrackingId) { return new DataflowExecutionContext.DataflowExecutionStateTracker( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 364608be82ca..664999f0d864 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -17,22 +17,25 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.processing; -import static org.apache.beam.sdk.options.ExperimentalOptions.hasExperiment; - import com.google.api.services.dataflow.model.MapTask; import com.google.auto.value.AutoValue; -import java.util.Optional; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutorFactory; import org.apache.beam.runners.dataflow.worker.HotKeyLogger; import org.apache.beam.runners.dataflow.worker.ReaderCache; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC; import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; @@ -45,7 +48,6 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingCounters; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.ExceptionUtils; @@ -53,16 +55,14 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; -import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.WorkFailureProcessor; import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.fn.IdGenerator; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -78,48 +78,46 @@ public class StreamingWorkScheduler { private static final Logger LOG = LoggerFactory.getLogger(StreamingWorkScheduler.class); - private final DataflowWorkerHarnessOptions options; private final Supplier clock; private final ComputationWorkExecutorFactory computationWorkExecutorFactory; - private final SideInputStateFetcherFactory sideInputStateFetcherFactory; private final FailureTracker failureTracker; private final WorkFailureProcessor workFailureProcessor; private final StreamingCommitFinalizer commitFinalizer; private final StreamingCounters streamingCounters; - private final HotKeyLogger hotKeyLogger; private final ConcurrentMap stageInfoMap; private final DataflowExecutionStateSampler sampler; private final StreamingGlobalConfigHandle globalConfigHandle; + private final BoundedQueueExecutor workExecutor; + private final boolean multiKeyExperimentEnabled; public StreamingWorkScheduler( - DataflowWorkerHarnessOptions options, Supplier clock, + BoundedQueueExecutor workExecutor, ComputationWorkExecutorFactory computationWorkExecutorFactory, - SideInputStateFetcherFactory sideInputStateFetcherFactory, FailureTracker failureTracker, WorkFailureProcessor workFailureProcessor, StreamingCommitFinalizer commitFinalizer, StreamingCounters streamingCounters, - HotKeyLogger hotKeyLogger, ConcurrentMap stageInfoMap, DataflowExecutionStateSampler sampler, - StreamingGlobalConfigHandle globalConfigHandle) { - this.options = options; + StreamingGlobalConfigHandle globalConfigHandle, + boolean multiKeyExperimentEnabled) { this.clock = clock; + this.workExecutor = workExecutor; this.computationWorkExecutorFactory = computationWorkExecutorFactory; - this.sideInputStateFetcherFactory = sideInputStateFetcherFactory; this.failureTracker = failureTracker; this.workFailureProcessor = workFailureProcessor; this.commitFinalizer = commitFinalizer; this.streamingCounters = streamingCounters; - this.hotKeyLogger = hotKeyLogger; this.stageInfoMap = stageInfoMap; this.sampler = sampler; this.globalConfigHandle = globalConfigHandle; + this.multiKeyExperimentEnabled = multiKeyExperimentEnabled; } public static StreamingWorkScheduler create( DataflowWorkerHarnessOptions options, + boolean multiKeyExperimentEnabled, Supplier clock, ReaderCache readerCache, DataflowMapTaskExecutorFactory mapTaskExecutorFactory, @@ -134,6 +132,9 @@ public static StreamingWorkScheduler create( IdGenerator idGenerator, StreamingGlobalConfigHandle globalConfigHandle, ConcurrentMap stageInfoMap) { + SideInputStateFetcherFactory sideInputStateFetcherFactory = + SideInputStateFetcherFactory.fromOptions(options); + ComputationWorkExecutorFactory computationWorkExecutorFactory = new ComputationWorkExecutorFactory( options, @@ -143,21 +144,22 @@ public static StreamingWorkScheduler create( sampler, streamingCounters.pendingDeltaCounters(), idGenerator, - globalConfigHandle); + globalConfigHandle, + hotKeyLogger, + sideInputStateFetcherFactory); return new StreamingWorkScheduler( - options, clock, + workExecutor, computationWorkExecutorFactory, - SideInputStateFetcherFactory.fromOptions(options), failureTracker, workFailureProcessor, StreamingCommitFinalizer.create(workExecutor, commitFinalizerCleanupExecutor), streamingCounters, - hotKeyLogger, stageInfoMap, sampler, - globalConfigHandle); + globalConfigHandle, + multiKeyExperimentEnabled); } private static long computeShuffleBytesRead(Windmill.WorkItem workItem) { @@ -185,23 +187,18 @@ private static Windmill.WorkItemCommitRequest buildWorkItemTruncationRequest( return outputBuilder.build(); } - /** Sets the stage name and workId of the Thread executing the {@link Work} for logging. */ - private static void setUpWorkLoggingContext(String workLatencyTrackingId, String computationId) { - DataflowWorkerLoggingMDC.setWorkId(workLatencyTrackingId); + private static void setLoggingContextComputation(@Nullable String computationId) { DataflowWorkerLoggingMDC.setStageName(computationId); } - private static String getShuffleTaskStepName(MapTask mapTask) { - // The MapTask instruction is ordered by dependencies, such that the first element is - // always going to be the shuffle task. - return mapTask.getInstructions().get(0).getName(); + private static void setLoggingContextWorkId(@Nullable String workLatencyTrackingId) { + DataflowWorkerLoggingMDC.setWorkId(workLatencyTrackingId); } /** Resets logging context of the Thread executing the {@link Work} for logging. */ - private void resetWorkLoggingContext(String workLatencyTrackingId) { - sampler.resetForWorkId(workLatencyTrackingId); - DataflowWorkerLoggingMDC.setWorkId(null); - DataflowWorkerLoggingMDC.setStageName(null); + private void resetWorkLoggingContext() { + setLoggingContextWorkId(null); + setLoggingContextComputation(null); } /** @@ -219,8 +216,14 @@ public void scheduleWork( computationState.activateWork( ExecutableWork.create( Work.create( - workItem, serializedWorkItemSize, watermarks, processingContext, drainMode, clock), - (work, handle) -> processWork(computationState, work, getWorkStreamLatencies, handle))); + workItem, + serializedWorkItemSize, + watermarks, + processingContext, + drainMode, + clock, + getWorkStreamLatencies), + (work, handle) -> processWork(computationState, work, handle))); } /** Adds any applied finalize ids to the commit finalizer to have their callbacks executed. */ @@ -234,95 +237,66 @@ public void queueAppliedFinalizeIds(ImmutableList appliedFinalizeIds) { * internally if processing fails due to uncaught {@link Exception}(s). * * @implNote This will block the calling thread during execution of user DoFns. - * @param handle handled to pass to BoundedQueueExecutor.pollWork, currently unused + * @param handle handled to pass to BoundedQueueExecutor.pollWork */ private void processWork( - ComputationState computationState, - Work work, - ImmutableList getWorkStreamLatencies, - BoundedQueueExecutorWorkHandle handle) { - work.recordGetWorkStreamLatencies(getWorkStreamLatencies); - processWork(computationState, work, handle); - } - - private void processWork( - ComputationState computationState, Work work, BoundedQueueExecutorWorkHandle unusedHandle) { + ComputationState computationState, Work work, BoundedQueueExecutorWorkHandle handle) { Windmill.WorkItem workItem = work.getWorkItem(); String computationId = computationState.getComputationId(); - ByteString key = workItem.getKey(); - work.setProcessingThreadName(Thread.currentThread().getName()); - work.setState(Work.State.PROCESSING); - setUpWorkLoggingContext(work.getLatencyTrackingId(), computationId); LOG.debug("Starting processing for {}:\n{}", computationId, work); + setLoggingContextComputation(computationId); + KeyTransitionListener keyTransitionListener = createKeyTransitionListener(); + keyTransitionListener.onKeyTransition(null, work); // Before any processing starts, call any pending OnCommit callbacks. Nothing that requires // cleanup should be done before this, since we might exit early here. commitFinalizer.finalizeCommits(workItem.getSourceState().getFinalizeIdsList()); + if (workItem.getSourceState().getOnlyFinalize()) { - Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); - outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true)); - work.setState(Work.State.COMMIT_QUEUED); - work.queueCommit(outputBuilder.build(), computationState); + handleOnlyFinalize(computationState, work, workItem); return; } long processingStartTimeNanos = System.nanoTime(); - MapTask mapTask = computationState.getMapTask(); - StageInfo stageInfo = - stageInfoMap.computeIfAbsent( - mapTask.getStageName(), s -> StageInfo.create(s, mapTask.getSystemName())); + StageInfo stageInfo = getStageInfo(computationState); + List workBatch = null; try { if (work.isFailed()) { throw new WorkItemCancelledException(workItem.getShardingKey()); } - // Execute the user code for the Work. - ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, computationState); - Windmill.WorkItemCommitRequest.Builder commitRequest = executeWorkResult.commitWorkRequest(); + // Execute the user code for the Work batch. + ExecuteWorkResult executeWorkResult = + executeWork(work, stageInfo, computationState, handle, keyTransitionListener); + workBatch = executeWorkResult.workBatch(); + List workItemCommits = executeWorkResult.workItemCommits(); + + commitFinalizer.cacheCommitFinalizers(executeWorkResult.accumulatedCallbacks()); - // Validate the commit request, possibly requesting truncation if the commitSize is too large. - Windmill.WorkItemCommitRequest validatedCommitRequest = - validateCommitRequestSize(commitRequest.build(), computationId, workItem); + commitWorkBatch(computationState, workBatch, workItemCommits); - // Queue the commit. - work.queueCommit(validatedCommitRequest, computationState); - recordProcessingStats(commitRequest, workItem, executeWorkResult); - LOG.debug("Processing done for work token: {}", workItem.getWorkToken()); + recordProcessingStats(workBatch, workItemCommits, executeWorkResult.stateBytesRead()); + LOG.debug("Processing done for work batch size: {}", workBatch.size()); } catch (Throwable t) { - // OutOfMemoryError that are caught will be rethrown and trigger jvm termination. - try { - workFailureProcessor.logAndProcessFailure( - computationId, - ExecutableWork.create(work, (retry, h) -> processWork(computationState, retry, h)), - t, - invalidWork -> - computationState.completeWorkAndScheduleNextWorkForKey( - invalidWork.getShardedKey(), invalidWork.id())); - } catch (OutOfMemoryError oom) { - throw oom; - } catch (Throwable t2) { - LOG.warn("Failed to process work failure safely for work {}", work.id(), t2); - throw ExceptionUtils.safeWrapThrowableAsException(t2); - } + handleProcessWorkFailure(computationState, handle.getWorkBatch(), computationId, work, t); } finally { // Update total processing time counters. Updating in finally clause ensures that // work items causing exceptions are also accounted in time spent. - long processingTimeMsecs = - TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); - stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); - - // Attribute all the processing to timers if the work item contains any timers. - // Tests show that work items rarely contain both timers and message bundles. It should - // be a fairly close approximation. - // Another option: Derive time split between messages and timers based on recent totals. - // either here or in DFE. - if (work.getWorkItem().hasTimers()) { - stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs); + recordProcessingTime( + stageInfo, + workBatch != null ? workBatch : ImmutableList.of(work), + processingStartTimeNanos); + + resetWorkLoggingContext(); + sampler.resetForWorkId(work.getLatencyTrackingId()); + if (workBatch != null) { + for (Work w : workBatch) { + w.setProcessingThreadName(""); + } + } else { + work.setProcessingThreadName(""); } - - resetWorkLoggingContext(work.getLatencyTrackingId()); - work.setProcessingThreadName(""); } } @@ -354,27 +328,38 @@ private Windmill.WorkItemCommitRequest validateCommitRequestSize( } private void recordProcessingStats( - Windmill.WorkItemCommitRequest.Builder outputBuilder, - Windmill.WorkItem workItem, - ExecuteWorkResult executeWorkResult) { - // Compute shuffle and state byte statistics these will be flushed asynchronously. - long stateBytesWritten = - outputBuilder - .clearOutputMessages() - .clearPerWorkItemLatencyAttributions() - .build() - .getSerializedSize(); - - streamingCounters.windmillShuffleBytesRead().addValue(computeShuffleBytesRead(workItem)); - streamingCounters.windmillStateBytesRead().addValue(executeWorkResult.stateBytesRead()); - streamingCounters.windmillStateBytesWritten().addValue(stateBytesWritten); + List workBatch, + List workItemCommits, + long totalStateBytesRead) { + long totalStateBytesWritten = 0; + long totalShuffleBytesRead = 0; + Preconditions.checkState(workBatch.size() == workItemCommits.size()); + for (int i = 0; i < workBatch.size(); i++) { + Windmill.WorkItem workItem = workBatch.get(i).getWorkItem(); + Windmill.WorkItemCommitRequest commit = workItemCommits.get(i); + // Compute shuffle and state byte statistics these will be flushed asynchronously. + long stateBytesWritten = + commit + .toBuilder() + .clearOutputMessages() + .clearPerWorkItemLatencyAttributions() + .build() + .getSerializedSize(); + totalStateBytesWritten += stateBytesWritten; + totalShuffleBytesRead += computeShuffleBytesRead(workItem); + } + streamingCounters.windmillShuffleBytesRead().addValue(totalShuffleBytesRead); + streamingCounters.windmillStateBytesRead().addValue(totalStateBytesRead); + streamingCounters.windmillStateBytesWritten().addValue(totalStateBytesWritten); } private ExecuteWorkResult executeWork( - Work work, StageInfo stageInfo, ComputationState computationState) throws Exception { - Windmill.WorkItem workItem = work.getWorkItem(); - ByteString key = workItem.getKey(); - Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); + Work work, + StageInfo stageInfo, + ComputationState computationState, + BoundedQueueExecutorWorkHandle handle, + KeyTransitionListener keyTransitionListener) + throws Exception { ComputationWorkExecutor computationWorkExecutor = computationState .acquireComputationWorkExecutor() @@ -384,90 +369,209 @@ private ExecuteWorkResult executeWork( stageInfo, computationState, work.getLatencyTrackingId())); try { - WindmillStateReader stateReader = work.createWindmillStateReader(); - SideInputStateFetcher localSideInputStateFetcher = - sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput); - - // If the read output KVs, then we can decode Windmill's byte key into userland - // key object and provide it to the execution context for use with per-key state. - // Otherwise, we pass null. - // - // The coder type that will be present is: - // WindowedValueCoder(TimerOrElementCoder(KvCoder)) - Optional> keyCoder = computationWorkExecutor.keyCoder(); - @SuppressWarnings("deprecation") - @Nullable - final Object executionKey = - !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), Coder.Context.OUTER); - - if (workItem.hasHotKeyInfo()) { - Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo(); - Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 1000); - - String stepName = getShuffleTaskStepName(computationState.getMapTask()); - if (executionKey != null - && (options.isHotKeyLoggingEnabled() - || hasExperiment(options, "enable_hot_key_logging")) - && keyCoder.isPresent()) { - hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey); - } else { - hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge); - } - } + StreamingModeExecutionContext context = computationWorkExecutor.context(); // Blocks while executing work. - computationWorkExecutor.executeWork( - executionKey, work, stateReader, localSideInputStateFetcher, outputBuilder); + computationWorkExecutor.executeWork(work, workExecutor, handle, keyTransitionListener); + + List workBatch; + List workItemCommits; + Map> accumulatedCallbacks; + long stateBytesRead; + { + if (context.workIsFailed()) { + throw new WorkItemCancelledException(work.getWorkItem().getShardingKey()); + } - if (work.isFailed()) { - throw new WorkItemCancelledException(workItem.getShardingKey()); - } + // Retrieve executed works, work item commits, and accumulated callbacks from execution + // context + workBatch = context.getExecutedWorks(); + workItemCommits = context.getWorkItemCommits(); + accumulatedCallbacks = context.getAccumulatedCallbacks(); + stateBytesRead = context.getStateBytesRead(); - // Reports source bytes processed to WorkItemCommitRequest if available. - try { - long sourceBytesProcessed = - computationWorkExecutor.computeSourceBytesProcessed( - computationState.sourceBytesProcessCounterName()); - outputBuilder.setSourceBytesProcessed(sourceBytesProcessed); - } catch (Exception e) { - LOG.error("{}", e.toString()); + context.clear(); // Don't use context after this. } - - commitFinalizer.cacheCommitFinalizers(computationWorkExecutor.context().flushState()); - // Release the execution state for another thread to use. computationState.releaseComputationWorkExecutor(computationWorkExecutor); computationWorkExecutor = null; - work.setState(Work.State.COMMIT_QUEUED); - outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler)); - return ExecuteWorkResult.create( - outputBuilder, stateReader.getBytesRead() + localSideInputStateFetcher.getBytesRead()); + workBatch, workItemCommits, accumulatedCallbacks, stateBytesRead); } catch (Throwable t) { if (computationWorkExecutor != null) { // If processing failed due to a thrown exception, close the executionState. Do not // return/release the executionState back to computationState as that will lead to this // executionState instance being reused. - LOG.debug("Invalidating executor after work item {} failed", workItem.getWorkToken(), t); + LOG.debug( + "Invalidating executor after work item {} failed", + work.getWorkItem().getWorkToken(), + t); computationWorkExecutor.invalidate(); } - // Re-throw the exception, it will be caught and handled by workFailureProcessor downstream. throw t; } } + private void handleOnlyFinalize( + ComputationState computationState, Work work, Windmill.WorkItem workItem) { + Windmill.WorkItemCommitRequest.Builder outputBuilder = + initializeOutputBuilder(workItem.getKey(), workItem); + outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true)); + work.setState(Work.State.COMMIT_QUEUED); + work.queueCommit(outputBuilder.build(), computationState); + } + + private StageInfo getStageInfo(ComputationState computationState) { + MapTask mapTask = computationState.getMapTask(); + return stageInfoMap.computeIfAbsent( + mapTask.getStageName(), s -> StageInfo.create(s, mapTask.getSystemName())); + } + + private void commitWorkBatch( + ComputationState computationState, + List workBatch, + List workItemCommits) { + if (workBatch.isEmpty()) { + return; + } + if (workBatch.size() > 1 || multiKeyExperimentEnabled) { + commitMultiKeyWorkBatch(computationState, workBatch, workItemCommits); + } else { + commitSingleKeyWork(computationState, workBatch.get(0), workItemCommits.get(0)); + } + } + + private void commitMultiKeyWorkBatch( + ComputationState computationState, + List workBatch, + List workItemCommits) { + Windmill.MultiKeyWorkItemCommitRequest.Builder multiKeyBuilder = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder(); + + Work primaryWork = workBatch.get(0); + Work.KeyGroup keyGroup = primaryWork.getKeyGroup(); + multiKeyBuilder.setKeyGroup( + Windmill.Uint128Proto.newBuilder().setHigh(keyGroup.high()).setLow(keyGroup.low()).build()); + + for (int i = 0; i < workBatch.size(); i++) { + // TODO: Add commit size validation + Windmill.WorkItemCommitRequest commit = workItemCommits.get(i); + Work w = workBatch.get(i); + multiKeyBuilder.addRequests( + commit + .toBuilder() + .addAllPerWorkItemLatencyAttributions(w.getLatencyAttributions(sampler)) + .build()); + } + + // Transition states of all completed works in the batch to COMMIT_QUEUED and submit + for (Work w : workBatch) { + w.setState(Work.State.COMMIT_QUEUED); + } + + // Package and submit the commit batch transactionally + primaryWork + .workCommitter() + .accept( + Commit.createMultiKey( + multiKeyBuilder.build(), computationState, ImmutableList.copyOf(workBatch))); + } + + private void commitSingleKeyWork( + ComputationState computationState, Work work, Windmill.WorkItemCommitRequest commitRequest) { + // Validate the commit request, possibly requesting truncation if the commitSize is too large. + Windmill.WorkItemCommitRequest validatedCommitRequest = + validateCommitRequestSize( + commitRequest, computationState.getComputationId(), work.getWorkItem()); + work.setState(Work.State.COMMIT_QUEUED); + validatedCommitRequest = + validatedCommitRequest + .toBuilder() + .addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler)) + .build(); + work.queueCommit(validatedCommitRequest, computationState); + } + + private void handleProcessWorkFailure( + ComputationState computationState, + List failedBatch, + String computationId, + Work primaryWork, + Throwable t) { + try { + List executableWorks = new ArrayList<>(); + for (Work w : failedBatch) { + executableWorks.add( + ExecutableWork.create(w, (retry, h) -> processWork(computationState, retry, h))); + } + + workFailureProcessor.logAndProcessFailureBatch( + computationId, + executableWorks, + t, + invalidWork -> + computationState.completeWorkAndScheduleNextWorkForKey( + invalidWork.getShardedKey(), invalidWork.id())); + } catch (OutOfMemoryError oom) { + throw oom; + } catch (Throwable t2) { + LOG.warn("Failed to process work failure safely for work {}", primaryWork.id(), t2); + throw ExceptionUtils.safeWrapThrowableAsException(t2); + } + } + + private void recordProcessingTime( + StageInfo stageInfo, List workBatch, long processingStartTimeNanos) { + long processingTimeMsecs = + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); + stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); + if (anyWorkHasTimers(workBatch)) { + // Attribute all the processing to timers if the work item contains any timers. + // Tests show that work items rarely contain both timers and message bundles. It should + // be a fairly close approximation. + // Another option: Derive time split between messages and timers based on recent totals. + // either here or in DFE. + stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs); + } + } + + private static boolean anyWorkHasTimers(List works) { + return works.stream().anyMatch(w -> w.getWorkItem().hasTimers()); + } + + private KeyTransitionListener createKeyTransitionListener() { + return (oldWork, newWork) -> { + newWork.recordGetWorkStreamLatencies(); + newWork.setState(Work.State.PROCESSING); + setLoggingContextWorkId(newWork.getLatencyTrackingId()); + if (oldWork != null) { + newWork.setProcessingThreadName(oldWork.getProcessingThreadName()); + oldWork.setProcessingThreadName(""); + } else { + newWork.setProcessingThreadName(Thread.currentThread().getName()); + } + }; + } + @AutoValue abstract static class ExecuteWorkResult { - - private static ExecuteWorkResult create( - Windmill.WorkItemCommitRequest.Builder commitWorkRequest, long stateBytesRead) { + static ExecuteWorkResult create( + List workBatch, + List workItemCommits, + Map> accumulatedCallbacks, + long stateBytesRead) { return new AutoValue_StreamingWorkScheduler_ExecuteWorkResult( - commitWorkRequest, stateBytesRead); + workBatch, workItemCommits, accumulatedCallbacks, stateBytesRead); } - abstract Windmill.WorkItemCommitRequest.Builder commitWorkRequest(); + abstract List workBatch(); + + abstract List workItemCommits(); + + // Map> + abstract Map> accumulatedCallbacks(); abstract long stateBytesRead(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java index 18c8e9b8d83c..15ec1e0c2cf3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java @@ -17,13 +17,12 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures; +import java.util.List; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Supplier; import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; -import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; -import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Work; @@ -99,28 +98,41 @@ private static boolean isOutOfMemoryError(@Nullable Throwable t) { return false; } - /** - * Processes failures caused by thrown exceptions that occur during execution of {@link Work}. May - * attempt to retry execution of the {@link Work} or drop it if it is invalid. - */ - public void logAndProcessFailure( + public void logAndProcessFailureBatch( String computationId, - ExecutableWork executableWork, + List executableWorks, Throwable t, Consumer onInvalidWork) throws Throwable { - switch (evaluateRetry(computationId, executableWork.work(), t)) { - case DO_NOT_RETRY: - // Consider the item invalid. It will eventually be retried by Windmill if it still needs to - // be processed. - onInvalidWork.accept(executableWork.work()); - break; - case RETRY_LOCALLY: - // Try again after some delay and at the end of the queue to avoid a tight loop. - executeWithDelay(retryLocallyDelayMs, executableWork); - break; - case RETHROW_THROWABLE: - throw t; + List worksToRetryLocally = new java.util.ArrayList<>(); + + for (ExecutableWork executableWork : executableWorks) { + switch (evaluateRetry(computationId, executableWork.work(), t)) { + case DO_NOT_RETRY: + // Consider the item invalid. It will eventually be retried by Windmill if it still needs + // to + // be processed. + onInvalidWork.accept(executableWork.work()); + break; + case RETRY_LOCALLY: + // Try again after some delay and at the end of the queue to avoid a tight loop. + worksToRetryLocally.add(executableWork); + break; + case RETHROW_THROWABLE: + throw t; + } + } + + executeWithDelay(worksToRetryLocally); + } + + private void executeWithDelay(List worksToRetryLocally) { + if (!worksToRetryLocally.isEmpty()) { + // Sleep ONCE for the entire batch delay to avoid sequential thread blocks + Uninterruptibles.sleepUninterruptibly(retryLocallyDelayMs, TimeUnit.MILLISECONDS); + for (ExecutableWork ew : worksToRetryLocally) { + workUnitExecutor.forceExecute(ew, ew.work().getSerializedWorkItemSize()); + } } } @@ -131,12 +143,6 @@ private String tryToDumpHeap() { .orElseGet(() -> "not written"); } - private void executeWithDelay(long delayMs, ExecutableWork executableWork) { - Uninterruptibles.sleepUninterruptibly(delayMs, TimeUnit.MILLISECONDS); - workUnitExecutor.forceExecute( - executableWork, executableWork.work().getSerializedWorkItemSize()); - } - private enum RetryEvaluation { DO_NOT_RETRY, RETRY_LOCALLY, @@ -144,24 +150,16 @@ private enum RetryEvaluation { } private RetryEvaluation evaluateRetry(String computationId, Work work, Throwable t) { - @Nullable final Throwable cause = t.getCause(); - Throwable parsedException = (t instanceof UserCodeException && cause != null) ? cause : t; - if (KeyTokenInvalidException.isKeyTokenInvalidException(parsedException)) { - LOG.debug( - "Execution of work for computation '{}' on sharding key '{}' failed due to token expiration. " - + "Work will not be retried locally.", - computationId, - work.getWorkItem().getShardingKey()); - return RetryEvaluation.DO_NOT_RETRY; - } - if (WorkItemCancelledException.isWorkItemCancelledException(parsedException)) { + if (work.isFailed()) { LOG.debug( "Execution of work for computation '{}' on sharding key '{}' failed. " - + "Work will not be retried locally.", + + "Work is already marked as failed, not retrying locally.", computationId, work.getWorkItem().getShardingKey()); return RetryEvaluation.DO_NOT_RETRY; } + @Nullable final Throwable cause = t.getCause(); + Throwable parsedException = (t instanceof UserCodeException && cause != null) ? cause : t; LastExceptionDataProvider.reportException(parsedException); LOG.debug("Failed work: {}", work); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index 5be8ec0a6c72..eec77ccf435b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -29,7 +29,6 @@ import java.util.ArrayList; import java.util.Collection; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -37,6 +36,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -89,6 +89,8 @@ public final class FakeWindmillServer extends WindmillServerStub { private final Map streamingCommitsToOffer; // Keys are work tokens. private final Map commitsReceived; + private final List multiKeyCommitsReceived = + new CopyOnWriteArrayList<>(); private final ArrayList statsReceived; private final LinkedBlockingQueue exceptions; private final AtomicInteger expectedExceptionCount; @@ -118,7 +120,7 @@ public FakeWindmillServer( commitsToOffer = new ResponseQueue() .returnByDefault(CommitWorkResponse.getDefaultInstance()); - streamingCommitsToOffer = new HashMap<>(); + streamingCommitsToOffer = new ConcurrentHashMap<>(); commitsReceived = new ConcurrentHashMap<>(); exceptions = new LinkedBlockingQueue<>(); expectedExceptionCount = new AtomicInteger(); @@ -400,6 +402,7 @@ public void shutdown() {} public RequestBatcher batcher() { return new RequestBatcher() { final List requests = new ArrayList<>(); + final List multiKeyRequests = new ArrayList<>(); @Override public boolean commitWorkItem( @@ -423,6 +426,18 @@ public boolean commitWorkItem( return true; } + @Override + public boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone) { + LOG.debug("commitWorkStream::commitMultiKeyWorkItem: {}", request); + if (multiKeyRequests.size() > 5) return false; + multiKeyRequests.add(new MultiKeyRequestAndDone(request, onDone)); + flush(); + return true; + } + @Override public void flush() { for (RequestAndDone elem : requests) { @@ -445,6 +460,37 @@ public void flush() { .orElse(Windmill.CommitStatus.OK)); } requests.clear(); + + for (MultiKeyRequestAndDone elem : multiKeyRequests) { + if (dropStreamingCommits) { + for (WorkItemCommitRequest workRequest : elem.request.getRequestsList()) { + droppedStreamingCommits.put(workRequest.getWorkToken(), elem.onDone); + } + continue; + } + + multiKeyCommitsReceived.add(elem.request); + for (WorkItemCommitRequest workRequest : elem.request.getRequestsList()) { + commitsReceived.put(workRequest.getWorkToken(), workRequest); + } + + // Determine status for the batch. + // Default to OK, but if any of the works in the batch has an offered status, use it. + Windmill.CommitStatus status = Windmill.CommitStatus.OK; + for (WorkItemCommitRequest workRequest : elem.request.getRequestsList()) { + Windmill.CommitStatus offeredStatus = + streamingCommitsToOffer.remove( + WorkId.builder() + .setWorkToken(workRequest.getWorkToken()) + .setCacheToken(workRequest.getCacheToken()) + .build()); + if (offeredStatus != null) { + status = offeredStatus; + } + } + elem.onDone.accept(status); + } + multiKeyRequests.clear(); } class RequestAndDone { @@ -456,6 +502,18 @@ class RequestAndDone { this.onDone = onDone; } } + + class MultiKeyRequestAndDone { + final Consumer onDone; + final Windmill.MultiKeyWorkItemCommitRequest request; + + MultiKeyRequestAndDone( + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone) { + this.request = request; + this.onDone = onDone; + } + } }; } @@ -518,6 +576,15 @@ public Map waitForAndGetCommits(int numCommits) { public void clearCommitsReceived() { commitsRequested = 0; commitsReceived.clear(); + multiKeyCommitsReceived.clear(); + } + + public List getMultiKeyCommitsReceived() { + return multiKeyCommitsReceived; + } + + public void clearMultiKeyCommitsReceived() { + multiKeyCommitsReceived.clear(); } public ConcurrentHashMap> waitForDroppedCommits( diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidExceptionTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidExceptionTest.java deleted file mode 100644 index 1eb2871e8cd3..000000000000 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/KeyTokenInvalidExceptionTest.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.dataflow.worker; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for {@link KeyTokenInvalidException}. */ -@RunWith(JUnit4.class) -public final class KeyTokenInvalidExceptionTest { - @Test - public void testIsKeyTokenInvalidException() throws Exception { - KeyTokenInvalidException exception = new KeyTokenInvalidException("test"); - RuntimeException keyTokenCauseException = new RuntimeException("key token cause", exception); - assertTrue(KeyTokenInvalidException.isKeyTokenInvalidException(exception)); - assertTrue(KeyTokenInvalidException.isKeyTokenInvalidException(keyTokenCauseException)); - assertFalse( - KeyTokenInvalidException.isKeyTokenInvalidException(new RuntimeException("non key token"))); - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 5bcdffcc2564..dbb1cc45e1b8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -420,6 +420,7 @@ private ParallelInstruction makeWindowingSourceInstruction(Coder coder) { CloudObjects.asCloudObject(IntervalWindowCoder.of(), /* sdkComponents= */ null))); return new ParallelInstruction() + .setName(DEFAULT_SOURCE_SYSTEM_NAME) .setSystemName(DEFAULT_SOURCE_SYSTEM_NAME) .setOriginalName(DEFAULT_SOURCE_ORIGINAL_NAME) .setRead( @@ -439,6 +440,7 @@ private ParallelInstruction makeWindowingSourceInstruction(Coder coder) { private ParallelInstruction makeSourceInstruction(Coder coder) { return new ParallelInstruction() + .setName(DEFAULT_SOURCE_SYSTEM_NAME) .setSystemName(DEFAULT_SOURCE_SYSTEM_NAME) .setOriginalName(DEFAULT_SOURCE_ORIGINAL_NAME) .setRead( @@ -527,6 +529,7 @@ private ParallelInstruction makeSinkInstruction( CloudObject spec = CloudObject.forClass(WindmillSink.class); addString(spec, "stream_id", streamId); return new ParallelInstruction() + .setName(streamId) .setSystemName(DEFAULT_SINK_SYSTEM_NAME) .setOriginalName(DEFAULT_SINK_ORIGINAL_NAME) .setWrite( @@ -571,11 +574,16 @@ private Windmill.GetWorkResponse buildInput(String input, byte[] metadata) throw Windmill.GetWorkResponse.Builder builder = Windmill.GetWorkResponse.newBuilder(); TextFormat.merge(input, builder); if (metadata != null) { - Windmill.InputMessageBundle.Builder messageBundleBuilder = - builder.getWorkBuilder(0).getWorkBuilder(0).getMessageBundlesBuilder(0); - for (Windmill.Message.Builder messageBuilder : - messageBundleBuilder.getMessagesBuilderList()) { - messageBuilder.setMetadata(addPaneTag(PaneInfo.NO_FIRING, metadata)); + for (Windmill.ComputationWorkItems.Builder compBuilder : builder.getWorkBuilderList()) { + for (Windmill.WorkItem.Builder workBuilder : compBuilder.getWorkBuilderList()) { + for (Windmill.InputMessageBundle.Builder messageBundleBuilder : + workBuilder.getMessageBundlesBuilderList()) { + for (Windmill.Message.Builder messageBuilder : + messageBundleBuilder.getMessagesBuilderList()) { + messageBuilder.setMetadata(addPaneTag(PaneInfo.NO_FIRING, metadata)); + } + } + } } } @@ -893,7 +901,7 @@ private ByteString addPaneTag(PaneInfo paneInfo, byte[] windowBytes) throws IOEx } private DataflowWorkerHarnessOptions createTestingPipelineOptions(String... args) { - List argsList = Lists.newArrayList(args); + List argsList = new ArrayList<>(Arrays.asList(args)); if (streamingEngine) { argsList.add("--experiments=enable_streaming_engine"); } @@ -1244,9 +1252,8 @@ public void testNumberOfWorkerHarnessThreadsIsHonored() throws Exception { } @Test - public void testKeyTokenInvalidException() throws Exception { - if (streamingEngine) { - // TODO: This test needs to be adapted to work with streamingEngine=true. + public void testMultiKeyCommit_success() throws Exception { + if (!streamingEngine) { return; } KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); @@ -1254,30 +1261,359 @@ public void testKeyTokenInvalidException() throws Exception { List instructions = Arrays.asList( makeSourceInstruction(kvCoder), - makeDoFnInstruction(new KeyTokenInvalidFn(), 0, kvCoder), + makeDoFnInstruction(new WorkDoFn(), 0, kvCoder), makeSinkInstruction(kvCoder, 1)); + StreamingDataflowWorker worker = + makeWorker( + defaultWorkerParams( + "--experiments=unstable_enable_multi_key_bundle,windmill_max_key_group_batch_time_ms=50000", + "--numberOfWorkerHarnessThreads=1") + .setLocalRetryTimeoutMs(100) + .setInstructions(instructions) + .build()); + worker.start(); + + String batchInputText = + "work {" + + " computation_id: \"" + + DEFAULT_COMPUTATION_ID + + "\"" + + " input_data_watermark: 0" + + " work {" + + " key: \"key1\"" + + " sharding_key: 1" + + " work_token: 1" + + " cache_token: 2" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data1\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key2\"" + + " sharding_key: 2" + + " work_token: 2" + + " cache_token: 3" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data2\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key3\"" + + " sharding_key: 3" + + " work_token: 3" + + " cache_token: 4" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data3\"" + + " }" + + " }" + + " }" + + "}"; + Windmill.GetWorkResponse batchInput = + buildInput( + batchInputText, + CoderUtils.encodeToByteArray( + CollectionCoder.of(IntervalWindow.getCoder()), + Collections.singletonList(DEFAULT_WINDOW))); + server - .whenGetWorkCalled() - .thenReturn(makeInput(0, 0, DEFAULT_KEY_STRING, DEFAULT_SHARDING_KEY)); + .whenGetDataCalled() + .answerByDefault( + request -> { + Windmill.GetDataResponse.Builder builder = Windmill.GetDataResponse.newBuilder(); + for (ComputationGetDataRequest compRequest : request.getRequestsList()) { + ComputationGetDataResponse.Builder compBuilder = + builder.addDataBuilder().setComputationId(compRequest.getComputationId()); + for (KeyedGetDataRequest keyRequest : compRequest.getRequestsList()) { + KeyedGetDataResponse.Builder keyBuilder = + compBuilder + .addDataBuilder() + .setKey(keyRequest.getKey()) + .setShardingKey(keyRequest.getShardingKey()); + keyBuilder.addAllValues(keyRequest.getValuesToFetchList()); + keyBuilder.addAllBags(keyRequest.getBagsToFetchList()); + keyBuilder.addAllWatermarkHolds(keyRequest.getWatermarkHoldsToFetchList()); + } + } + return builder.build(); + }); + + server.whenGetWorkCalled().thenReturn(batchInput); + + Map result = server.waitForAndGetCommits(3); + + assertEquals(3, result.size()); + + List multiKeyCommits = + server.getMultiKeyCommitsReceived(); + assertEquals(1, multiKeyCommits.size()); + Windmill.MultiKeyWorkItemCommitRequest multiKeyCommit = multiKeyCommits.get(0); + assertEquals(3, multiKeyCommit.getRequestsCount()); + assertEquals(1, multiKeyCommit.getRequests(0).getWorkToken()); + assertEquals(2, multiKeyCommit.getRequests(1).getWorkToken()); + assertEquals(3, multiKeyCommit.getRequests(2).getWorkToken()); + + worker.stop(); + } + + @Test + public void testMultiKeyCommit_elementFailure() throws Exception { + if (!streamingEngine) { + return; + } + KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); + + List instructions = + Arrays.asList( + makeSourceInstruction(kvCoder), + makeDoFnInstruction(new WorkDoFn(), 0, kvCoder), + makeSinkInstruction(kvCoder, 1)); StreamingDataflowWorker worker = - makeWorker(defaultWorkerParams().setInstructions(instructions).publishCounters().build()); + makeWorker( + defaultWorkerParams( + "--experiments=unstable_enable_multi_key_bundle,windmill_max_key_group_batch_time_ms=5000", + "--numberOfWorkerHarnessThreads=1") + .setLocalRetryTimeoutMs(100) + .setInstructions(instructions) + .build()); worker.start(); - server.waitForEmptyWorkQueue(); + String batchInputText = + "work {" + + " computation_id: \"" + + DEFAULT_COMPUTATION_ID + + "\"" + + " input_data_watermark: 0" + + " work {" + + " key: \"key1\"" + + " sharding_key: 1" + + " work_token: 1" + + " cache_token: 2" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data1\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key2\"" + + " sharding_key: 2" + + " work_token: 2" + + " cache_token: 3" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data2\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key3\"" + + " sharding_key: 3" + + " work_token: 3" + + " cache_token: 4" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data3\"" + + " }" + + " }" + + " }" + + "}"; + Windmill.GetWorkResponse batchInput = + buildInput( + batchInputText, + CoderUtils.encodeToByteArray( + CollectionCoder.of(IntervalWindow.getCoder()), + Collections.singletonList(DEFAULT_WINDOW))); server - .whenGetWorkCalled() - .thenReturn(makeInput(1, 0, DEFAULT_KEY_STRING, DEFAULT_SHARDING_KEY)); + .whenGetDataCalled() + .answerByDefault( + request -> { + Windmill.GetDataResponse.Builder builder = Windmill.GetDataResponse.newBuilder(); + for (ComputationGetDataRequest compRequest : request.getRequestsList()) { + ComputationGetDataResponse.Builder compBuilder = + builder.addDataBuilder().setComputationId(compRequest.getComputationId()); + for (KeyedGetDataRequest keyRequest : compRequest.getRequestsList()) { + KeyedGetDataResponse.Builder keyBuilder = + compBuilder + .addDataBuilder() + .setKey(keyRequest.getKey()) + .setShardingKey(keyRequest.getShardingKey()); + if (keyRequest.getWorkToken() == 2) { + keyBuilder.setFailed(true); + } else { + keyBuilder.addAllValues(keyRequest.getValuesToFetchList()); + keyBuilder.addAllBags(keyRequest.getBagsToFetchList()); + keyBuilder.addAllWatermarkHolds(keyRequest.getWatermarkHoldsToFetchList()); + } + } + } + return builder.build(); + }); + + server.whenGetWorkCalled().thenReturn(batchInput); + + Map result = server.waitForAndGetCommits(2); + + assertTrue(result.containsKey(1L)); + assertTrue(result.containsKey(3L)); + assertFalse(result.containsKey(2L)); + + List multiKeyCommits = + server.getMultiKeyCommitsReceived(); + assertEquals(1, multiKeyCommits.size()); + Windmill.MultiKeyWorkItemCommitRequest multiKeyCommit = multiKeyCommits.get(0); + assertEquals(2, multiKeyCommit.getRequestsCount()); + assertEquals(3, multiKeyCommit.getRequests(0).getWorkToken()); + assertEquals(1, multiKeyCommit.getRequests(1).getWorkToken()); + + worker.stop(); + } + + @Test + public void testCompleteCommit_retryableFailureTriggersReExecution() throws Exception { + if (!streamingEngine) { + return; + } + KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); + + List instructions = + Arrays.asList( + makeSourceInstruction(kvCoder), + makeDoFnInstruction(new WorkDoFn(), 0, kvCoder), + makeSinkInstruction(kvCoder, 1)); + + StreamingDataflowWorker worker = + makeWorker( + defaultWorkerParams( + "--experiments=unstable_enable_multi_key_bundle,max_key_group_batch_time_ms=5000", + "--numberOfWorkerHarnessThreads=1") + .setLocalRetryTimeoutMs(100) + .setInstructions(instructions) + .build()); + worker.start(); + + String batchInputText = + "work {" + + " computation_id: \"" + + DEFAULT_COMPUTATION_ID + + "\"" + + " input_data_watermark: 0" + + " work {" + + " key: \"key1\"" + + " sharding_key: 1" + + " work_token: 1" + + " cache_token: 2" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data1\"" + + " }" + + " }" + + " }" + + " work {" + + " key: \"key2\"" + + " sharding_key: 2" + + " work_token: 2" + + " cache_token: 3" + + " key_group { high: 0 low: 1 }" + + " message_bundles {" + + " source_computation_id: \"" + + DEFAULT_SOURCE_COMPUTATION_ID + + "\"" + + " messages {" + + " timestamp: 0" + + " data: \"data2\"" + + " }" + + " }" + + " }" + + "}"; + Windmill.GetWorkResponse batchInput = + buildInput( + batchInputText, + CoderUtils.encodeToByteArray( + CollectionCoder.of(IntervalWindow.getCoder()), + Collections.singletonList(DEFAULT_WINDOW))); + + server + .whenGetDataCalled() + .answerByDefault( + request -> { + Windmill.GetDataResponse.Builder builder = Windmill.GetDataResponse.newBuilder(); + for (ComputationGetDataRequest compRequest : request.getRequestsList()) { + ComputationGetDataResponse.Builder compBuilder = + builder.addDataBuilder().setComputationId(compRequest.getComputationId()); + for (KeyedGetDataRequest keyRequest : compRequest.getRequestsList()) { + KeyedGetDataResponse.Builder keyBuilder = + compBuilder + .addDataBuilder() + .setKey(keyRequest.getKey()) + .setShardingKey(keyRequest.getShardingKey()); + if (keyRequest.getWorkToken() == 2) { + keyBuilder.setFailed(true); + } else { + keyBuilder.addAllValues(keyRequest.getValuesToFetchList()); + keyBuilder.addAllBags(keyRequest.getBagsToFetchList()); + keyBuilder.addAllWatermarkHolds(keyRequest.getWatermarkHoldsToFetchList()); + } + } + } + return builder.build(); + }); + + server.whenGetWorkCalled().thenReturn(batchInput); Map result = server.waitForAndGetCommits(1); - assertEquals( - makeExpectedOutput(1, 0, DEFAULT_KEY_STRING, DEFAULT_SHARDING_KEY, DEFAULT_KEY_STRING) - .build(), - removeDynamicFields(result.get(1L))); - assertEquals(1, result.size()); + assertTrue(result.containsKey(1L)); + assertFalse(result.containsKey(2L)); + + List multiKeyCommits = + server.getMultiKeyCommitsReceived(); + assertEquals(1, multiKeyCommits.size()); + Windmill.MultiKeyWorkItemCommitRequest multiKeyCommit = multiKeyCommits.get(0); + assertEquals(1, multiKeyCommit.getRequestsCount()); + assertEquals(1, multiKeyCommit.getRequests(0).getWorkToken()); worker.stop(); } @@ -1327,7 +1663,7 @@ public void testKeyCommitTooLargeException() throws Exception { makeExpectedTruncationRequestOutput( 1, "large_key", DEFAULT_SHARDING_KEY, largeCommit.getEstimatedWorkItemCommitBytes()) .build(), - largeCommit); + removeDynamicFields(largeCommit)); // Check this explicitly since the estimated commit bytes weren't actually // checked against an expected value in the previous step @@ -2495,6 +2831,7 @@ private List makeUnboundedSourcePipeline( return Arrays.asList( new ParallelInstruction() + .setName("Read") .setSystemName("Read") .setOriginalName("OriginalReadName") .setRead( @@ -3511,8 +3848,8 @@ public void testExceptionInvalidatesCache() throws Exception { } // Ensure that the invalidated dofn had tearDown called on them. - assertEquals(1, TestExceptionInvalidatesCacheFn.tearDownCallCount.get()); - assertEquals(2, TestExceptionInvalidatesCacheFn.setupCallCount.get()); + assertEquals(2, TestExceptionInvalidatesCacheFn.tearDownCallCount.get()); + assertEquals(3, TestExceptionInvalidatesCacheFn.setupCallCount.get()); worker.stop(); } @@ -3954,11 +4291,16 @@ public void testDoFnLatencyBreakdownsReportedOnCommit() throws Exception { LatencyAttribution.newBuilder().setState(State.ACTIVE).setTotalDurationMillis(100); for (LatencyAttribution la : commit.getPerWorkItemLatencyAttributionsList()) { if (la.getState() == State.ACTIVE) { - assertThat(la.getActiveLatencyBreakdownCount(), equalTo(1)); - assertThat( - la.getActiveLatencyBreakdown(0).getUserStepName(), equalTo(DEFAULT_PARDO_USER_NAME)); - Assert.assertTrue(la.getActiveLatencyBreakdown(0).hasProcessingTimesDistribution()); - Assert.assertFalse(la.getActiveLatencyBreakdown(0).hasActiveMessageMetadata()); + LatencyAttribution.ActiveLatencyBreakdown pardoBreakdown = null; + for (LatencyAttribution.ActiveLatencyBreakdown lb : la.getActiveLatencyBreakdownList()) { + if (DEFAULT_PARDO_USER_NAME.equals(lb.getUserStepName())) { + pardoBreakdown = lb; + break; + } + } + Assert.assertNotNull("Expected breakdown for " + DEFAULT_PARDO_USER_NAME, pardoBreakdown); + Assert.assertTrue(pardoBreakdown.hasProcessingTimesDistribution()); + Assert.assertFalse(pardoBreakdown.hasActiveMessageMetadata()); } } @@ -4529,18 +4871,19 @@ public void evaluate() throws Throwable { } } - static class KeyTokenInvalidFn extends DoFn, KV> { - - static boolean thrown = false; + static class WorkDoFn extends DoFn, KV> { + @StateId("state") + private final StateSpec> stateSpec = StateSpecs.value(StringUtf8Coder.of()); @ProcessElement - public void processElement(ProcessContext c) { - if (!thrown) { - thrown = true; - throw new KeyTokenInvalidException("key"); - } else { - c.output(c.element()); + public void processElement(ProcessContext c, @StateId("state") ValueState state) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); } + state.read(); + c.output(c.element()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 13601410bfd9..539a17e97508 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -24,6 +24,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; @@ -57,22 +58,26 @@ import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.NoopProfileScope; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.FakeGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; -import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV1; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV2; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.Create; @@ -99,8 +104,7 @@ public class StreamingModeExecutionContextTest { @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - @Mock private SideInputStateFetcher sideInputStateFetcher; - @Mock private WindmillStateReader stateReader; + @Mock private WorkExecutor workExecutor; private static final String COMPUTATION_ID = "computationId"; @@ -112,7 +116,7 @@ public class StreamingModeExecutionContextTest { private FakeGlobalConfigHandle globalConfigHandle; private StreamingModeExecutionContext createExecutionContext( - StreamingGlobalConfigHandle configHandle) { + DataflowWorkerHarnessOptions options, StreamingGlobalConfigHandle configHandle) { CounterSet counterSet = new CounterSet(); ConcurrentHashMap stateNameMap = new ConcurrentHashMap<>(); stateNameMap.put(NameContextsForTests.nameContextForTest().userName(), "testStateFamily"); @@ -136,15 +140,24 @@ private StreamingModeExecutionContext createExecutionContext( executionStateRegistry, configHandle, Long.MAX_VALUE, - /*throwExceptionOnLargeOutput=*/ false); + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName", + options, + SideInputStateFetcherFactory.fromOptions(options)); } @Before public void setUp() { MockitoAnnotations.initMocks(this); options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + options + .as(ExperimentalOptions.class) + .setExperiments(Arrays.asList("unstable_enable_multi_key_bundle")); globalConfigHandle = new FakeGlobalConfigHandle(StreamingGlobalConfig.builder().build()); - executionContext = createExecutionContext(globalConfigHandle); + executionContext = createExecutionContext(options, globalConfigHandle); } private static Work createMockWork(Windmill.WorkItem workItem, Watermarks watermarks) { @@ -158,25 +171,40 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla Instant::now); } + private void start(Work work) { + start(executionContext, work, null); + } + + private void start(Work work, Coder keyCoder) { + start(executionContext, work, keyCoder); + } + + private void start(StreamingModeExecutionContext context, Work work) { + start(context, work, null); + } + + private void start(StreamingModeExecutionContext context, Work work, Coder keyCoder) { + context.start( + work, + workExecutor, + /* workQueueExecutor= */ null, + /* budgetHandle= */ null, + keyCoder, + /* keyTransitionListener= */ (k, c) -> {}); + } + @Test public void testTimerInternalsSetTimer() throws Exception { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); StreamingModeExecutionContext.StepContext stepContext = executionContext.getStepContext(operationContext); - executionContext.start( - "key", + start( createMockWork( Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); TimerInternals timerInternals = stepContext.timerInternals(); @@ -190,6 +218,7 @@ public void testTimerInternalsSetTimer() throws Exception { executionContext.finishKey(); executionContext.flushState(); + Windmill.WorkItemCommitRequest.Builder outputBuilder = executionContext.getOutputBuilder(); Windmill.Timer timer = outputBuilder.buildPartial().getOutputTimers(0); assertThat(timer.getTag().toStringUtf8(), equalTo("/skey+0:5000")); assertThat(timer.getTimestamp(), equalTo(TimeUnit.MILLISECONDS.toMicros(5000))); @@ -198,9 +227,6 @@ public void testTimerInternalsSetTimer() throws Exception { @Test public void testTimerInternalsProcessingTimeSkew() { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); - NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); @@ -220,15 +246,10 @@ public void testTimerInternalsProcessingTimeSkew() { .setTimestamp(timerTimestamp.getMillis() * 1000) .setType(Windmill.Timer.Type.REALTIME); - executionContext.start( - "key", + start( createMockWork( workItemBuilder.build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); TimerInternals timerInternals = stepContext.timerInternals(); assertTrue(timerTimestamp.isBefore(timerInternals.currentProcessingTime())); } @@ -429,35 +450,325 @@ public void testStateTagEncodingBasedOnConfig() { FakeGlobalConfigHandle configHandle = new FakeGlobalConfigHandle( StreamingGlobalConfig.builder().setEnableStateTagEncodingV2(isV2Encoding).build()); - StreamingModeExecutionContext context = createExecutionContext(configHandle); + StreamingModeExecutionContext context = createExecutionContext(options, configHandle); assertEquals(expectedEncoding, context.getWindmillTagEncoding().getClass()); } } @Test public void testSetBacklogBytes() { - Windmill.WorkItemCommitRequest.Builder outputBuilder = - Windmill.WorkItemCommitRequest.newBuilder(); NameContext nameContext = NameContextsForTests.nameContextForTest(); DataflowOperationContext operationContext = executionContext.createOperationContext(nameContext); StreamingModeExecutionContext.StepContext stepContext = executionContext.getStepContext(operationContext); - executionContext.start( - "key", + start( createMockWork( Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), - stateReader, - sideInputStateFetcher, - outputBuilder, - workExecutor); + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); stepContext.setBacklogBytes(1234.0); executionContext.finishKey(); executionContext.flushState(); - assertEquals(1234, outputBuilder.getSourceBacklogBytes()); + assertEquals(1234, executionContext.getOutputBuilder().getSourceBacklogBytes()); + } + + @Test + public void testFinishKeyReentrantSafety() { + start( + createMockWork( + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build())); + + // First call + executionContext.finishKey(); + // Second call - should not throw any Exception + executionContext.finishKey(); + } + + @Test + public void testStart_internalKeyDecoding() throws Exception { + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("decodedKey")) + .setWorkToken(17L) + .build(); + Work work = + createMockWork( + workItem, Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()); + + start(work, org.apache.beam.sdk.coders.StringUtf8Coder.of()); + + assertEquals("decodedKey", executionContext.getKey()); + } + + @Test + public void testAdvance_success() throws Exception { + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + Windmill.WorkItem workItem2 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setWorkToken(2L) + .setKeyGroup(keyGroup) + .build(); + Work work2 = + createMockWork( + workItem2, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + ExecutableWork executableWork2 = ExecutableWork.create(work2, (w, h) -> {}); + + org.mockito.Mockito.when( + mockExecutor.pollWork( + org.mockito.Mockito.eq(COMPUTATION_ID), + org.mockito.Mockito.eq(work1.getKeyGroup()), + org.mockito.Mockito.eq(mockHandle))) + .thenReturn(executableWork2); + + executionContext.start( + work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertTrue(executionContext.advance()); + assertEquals("key2", executionContext.getSerializedKey().toStringUtf8()); + } + + @Test + public void testAdvance_noMoreWork() throws Exception { + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + org.mockito.Mockito.when( + mockExecutor.pollWork( + org.mockito.Mockito.eq(COMPUTATION_ID), + org.mockito.Mockito.eq(work1.getKeyGroup()), + org.mockito.Mockito.eq(mockHandle))) + .thenReturn(null); + + executionContext.start( + work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(executionContext.advance()); + } + + @Test + public void testAdvance_respectsMaxBatchSize() throws Exception { + DataflowWorkerHarnessOptions optionsWithBatchSize = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + optionsWithBatchSize + .as(ExperimentalOptions.class) + .setExperiments(Arrays.asList("windmill_max_key_group_batch_size=1")); + StreamingModeExecutionContext context = + createExecutionContext(optionsWithBatchSize, globalConfigHandle); + + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + context.start(work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(context.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testAdvance_respectsMaxBatchTime() throws Exception { + DataflowWorkerHarnessOptions optionsWithBatchTime = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + optionsWithBatchTime + .as(ExperimentalOptions.class) + .setExperiments(Arrays.asList("windmill_max_key_group_batch_time_ms=0")); + StreamingModeExecutionContext context = + createExecutionContext(optionsWithBatchTime, globalConfigHandle); + + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + context.start(work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(context.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testAdvance_workFailed() throws Exception { + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + executionContext.start( + work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + work1.setFailed(); + + assertThrows(WorkItemCancelledException.class, () -> executionContext.advance()); + } + + @Test + public void testAdvance_defaultKeyGroup() throws Exception { + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + executionContext.start( + work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(executionContext.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testAdvance_experimentDisabled() throws Exception { + DataflowWorkerHarnessOptions optionsDisabled = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + StreamingModeExecutionContext context = + createExecutionContext(optionsDisabled, globalConfigHandle); + + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + context.start(work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + assertFalse(context.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testAdvance_respectsMaxBatchSinkBytes() throws Exception { + DataflowWorkerHarnessOptions optionsWithSinkBytes = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + optionsWithSinkBytes + .as(ExperimentalOptions.class) + .setExperiments( + Arrays.asList( + "unstable_enable_multi_key_bundle", "windmill_max_key_group_batch_sink_bytes=100")); + StreamingModeExecutionContext context = + createExecutionContext(optionsWithSinkBytes, globalConfigHandle); + + BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + BoundedQueueExecutorWorkHandle mockHandle = mock(BoundedQueueExecutorWorkHandle.class); + + Windmill.Uint128Proto keyGroup = + Windmill.Uint128Proto.newBuilder().setHigh(1).setLow(2).build(); + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(1L) + .setKeyGroup(keyGroup) + .build(); + Work work1 = + createMockWork( + workItem1, Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build()); + + context.start(work1, workExecutor, mockExecutor, mockHandle, null, (oldWork, newWork) -> {}); + + context.reportBytesSinked(50); + assertFalse(context.advance()); + org.mockito.Mockito.verify(mockExecutor) + .pollWork(COMPUTATION_ID, work1.getKeyGroup(), mockHandle); + + org.mockito.Mockito.reset(mockExecutor); + + context.reportBytesSinked(60); + assertFalse(context.advance()); + org.mockito.Mockito.verifyNoInteractions(mockExecutor); + } + + @Test + public void testExperimentParsingWithInvalidValues() { + DataflowWorkerHarnessOptions optionsInvalid = + PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + optionsInvalid + .as(ExperimentalOptions.class) + .setExperiments( + Arrays.asList( + "windmill_max_key_group_batch_size=invalid_size", + "windmill_max_key_group_batch_time_ms=invalid_time", + "windmill_max_key_group_batch_sink_bytes=invalid_bytes")); + + // This should not throw NumberFormatException + StreamingModeExecutionContext context = + createExecutionContext(optionsInvalid, globalConfigHandle); + + org.junit.Assert.assertNotNull(context); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java index 539c38eeb1da..b45e0de6447c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java @@ -30,6 +30,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.ThreadLocalRandom; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.options.ValueProvider; @@ -122,6 +123,7 @@ public void testFinishKeyCalled() throws Exception { .build()) .build(); when(mockContext.getWorkItem()).thenReturn(workItem); + when(mockContext.advance()).thenReturn(false); try (TestWindmillReaderIterator iter = new TestWindmillReaderIterator(mockContext, ValueProvider.StaticValueProvider.of(false))) { @@ -131,6 +133,76 @@ public void testFinishKeyCalled() throws Exception { } } + @Test + public void testAdvanceKeyChaining() throws Exception { + StreamingModeExecutionContext mockContext = mock(StreamingModeExecutionContext.class); + when(mockContext.workIsFailed()).thenReturn(false); + + // Work item A (1 message) + Windmill.WorkItem workItemA = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("keyA")) + .setWorkToken(100L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(1000) + .setData(ByteString.EMPTY) + .build()) + .build()) + .build(); + when(mockContext.getWorkItem()).thenReturn(workItemA); + + // Work item B (1 message) + Windmill.WorkItem workItemB = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("keyB")) + .setWorkToken(200L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(2000) + .setData(ByteString.EMPTY) + .build()) + .build()) + .build(); + + // Set up context.advance() to mock transition + when(mockContext.advance()) + .thenAnswer( + new org.mockito.stubbing.Answer() { + private int count = 0; + + @Override + public Boolean answer(org.mockito.invocation.InvocationOnMock invocation) { + if (count == 0) { + count++; + when(mockContext.getWorkItem()).thenReturn(workItemB); + return true; + } + return false; + } + }); + + try (TestWindmillReaderIterator iter = + new TestWindmillReaderIterator(mockContext, ValueProvider.StaticValueProvider.of(false))) { + assertTrue(iter.start()); + assertEquals(1000L, iter.getCurrent().getValue().longValue()); + + // Advance should trigger context.advance(), transition to workItemB, and decode message from + // workItemB (timestamp 2000) + assertTrue(iter.advance()); + assertEquals(2000L, iter.getCurrent().getValue().longValue()); + + // Next advance should exhaust it and return false + assertFalse(iter.advance()); + } + } + private void testForMessageBundleCounts(int... messageBundleCounts) throws IOException { testForMessageBundleCounts(false, messageBundleCounts); } @@ -179,4 +251,24 @@ private void testForMessageBundleCounts(boolean skipErrors, int... messageBundle assertEquals(Arrays.toString(messageBundleCounts) + skipErrors, expected, actual); } } + + private static Work createMockWork(Windmill.WorkItem workItem) { + return Work.create( + workItem, + workItem.getSerializedSize(), + org.apache.beam.runners.dataflow.worker.streaming.Watermarks.builder() + .setInputDataWatermark(new org.joda.time.Instant(1000)) + .build(), + Work.createProcessingContext( + "computationId", + mock( + org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient + .class), + ignored -> {}, + mock( + org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender + .class)), + false, + org.joda.time.Instant::now); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java new file mode 100644 index 000000000000..2e7c80330cf0 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.List; +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV1; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues.FullWindowedValueCoder; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class WindowingWindmillReaderTest { + private StreamingModeExecutionContext mockContext; + private WindowingWindmillReader reader; + + @SuppressWarnings("unchecked") + @Before + public void setUp() { + mockContext = mock(StreamingModeExecutionContext.class); + when(mockContext.workIsFailed()).thenReturn(false); + when(mockContext.getWindmillTagEncoding()).thenReturn(WindmillTagEncodingV1.instance()); + when(mockContext.getDrainMode()).thenReturn(false); + + Coder keyCoder = StringUtf8Coder.of(); + Coder valueCoder = VarLongCoder.of(); + KvCoder kvCoder = KvCoder.of(keyCoder, valueCoder); + WindmillKeyedWorkItem.FakeKeyedWorkItemCoder keyedWorkItemCoder = + (WindmillKeyedWorkItem.FakeKeyedWorkItemCoder) + WindmillKeyedWorkItem.FakeKeyedWorkItemCoder.of(kvCoder); + FullWindowedValueCoder> coder = + FullWindowedValueCoder.of(keyedWorkItemCoder, IntervalWindowCoder.of()); + + reader = + WindowingWindmillReader.create( + coder, mockContext, ValueProvider.StaticValueProvider.of(false)); + } + + private static Work createMockWork(Windmill.WorkItem workItem) { + return Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build(), + Work.createProcessingContext( + "computationId", new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), + false, + Instant::now); + } + + private static ByteString encodeMetadata(List windows) throws IOException { + ByteStringOutputStream stream = new ByteStringOutputStream(); + PaneInfoCoder.INSTANCE.encode(PaneInfo.NO_FIRING, stream); + ListCoder.of(IntervalWindowCoder.of()).encode(windows, stream); + return stream.toByteString(); + } + + private static ByteString encodeValue(long value) throws IOException { + ByteStringOutputStream stream = new ByteStringOutputStream(); + VarLongCoder.of().encode(value, stream); + return stream.toByteString(); + } + + @Test + public void testSingleNonEmptyKey() throws IOException { + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(1000)); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(100L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(1000) + .setData(encodeValue(42L)) + .setMetadata(encodeMetadata(ImmutableList.of(window))) + .build()) + .build()) + .build(); + Work work = createMockWork(workItem); + + when(mockContext.getKey()).thenReturn("key1"); + when(mockContext.getWorkItem()).thenReturn(workItem); + when(mockContext.getWork()).thenReturn(work); + when(mockContext.advance()).thenReturn(false); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + assertTrue(iter.start()); + WindowedValue> current = iter.getCurrent(); + assertEquals("key1", current.getValue().key()); + assertFalse(Iterables.isEmpty(current.getValue().elementsIterable())); + WindowedValue elem = Iterables.getOnlyElement(current.getValue().elementsIterable()); + assertEquals(42L, elem.getValue().longValue()); + + assertFalse(iter.advance()); + verify(mockContext).finishKey(); + } + } + + @Test + public void testSingleEmptyKey() throws IOException { + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(100L) + .build(); // No message bundles or timers + Work work = createMockWork(workItem); + + when(mockContext.getKey()).thenReturn("key1"); + when(mockContext.getWorkItem()).thenReturn(workItem); + when(mockContext.getWork()).thenReturn(work); + when(mockContext.advance()).thenReturn(false); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + assertFalse( + iter.start()); // Should skip the empty key and return false because advance returns false + verify(mockContext).finishKey(); + } + } + + @Test + public void testMultipleKeys_withEmptyAndNonEmpty() throws IOException { + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(1000)); + // Key 1: Empty + Windmill.WorkItem workItem1 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setWorkToken(100L) + .build(); + Work work1 = createMockWork(workItem1); + + // Key 2: Non-empty + Windmill.WorkItem workItem2 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setWorkToken(200L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(2000) + .setData(encodeValue(84L)) + .setMetadata(encodeMetadata(ImmutableList.of(window))) + .build()) + .build()) + .build(); + Work work2 = createMockWork(workItem2); + + // Key 3: Empty + Windmill.WorkItem workItem3 = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key3")) + .setWorkToken(300L) + .build(); + Work work3 = createMockWork(workItem3); + + // Initial state + when(mockContext.getKey()).thenReturn("key1"); + when(mockContext.getWorkItem()).thenReturn(workItem1); + when(mockContext.getWork()).thenReturn(work1); + + // Mock transition behaviour of context.advance() + when(mockContext.advance()) + .thenAnswer( + new org.mockito.stubbing.Answer() { + private int count = 0; + + @Override + public Boolean answer(org.mockito.invocation.InvocationOnMock invocation) { + if (count == 0) { + count++; + when(mockContext.getKey()).thenReturn("key2"); + when(mockContext.getWorkItem()).thenReturn(workItem2); + when(mockContext.getWork()).thenReturn(work2); + return true; + } else if (count == 1) { + count++; + when(mockContext.getKey()).thenReturn("key3"); + when(mockContext.getWorkItem()).thenReturn(workItem3); + when(mockContext.getWork()).thenReturn(work3); + return true; + } + return false; + } + }); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + // Key 1 is empty, so start() calls advance() which calls finishKey(1) and advance() to Key 2. + // Key 2 is non-empty, so start() returns true yielding Key 2. + assertTrue(iter.start()); + assertEquals("key2", iter.getCurrent().getValue().key()); + WindowedValue elem = + Iterables.getOnlyElement(iter.getCurrent().getValue().elementsIterable()); + assertEquals(84L, elem.getValue().longValue()); + + // Next advance() calls finishKey(2), calls advance() to Key 3. + // Key 3 is empty, so it loops, calls finishKey(3), calls advance() which returns false. + // So iter.advance() should return false. + assertFalse(iter.advance()); + + verify(mockContext, times(3)) + .finishKey(); // finishKey should have been called on key1, key2, key3 + } + } + + @Test + public void testWorkItemCancelled() throws IOException { + when(mockContext.workIsFailed()).thenReturn(true); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(0L).build(); + when(mockContext.getWorkItem()).thenReturn(workItem); + + try (NativeReader.NativeReaderIterator>> iter = + reader.iterator()) { + iter.start(); + fail("Expected WorkItemCancelledException"); + } catch (WorkItemCancelledException e) { + // Expected + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index d5cf2948d928..9d43b62d7c38 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -83,6 +83,7 @@ import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.runners.dataflow.worker.DataflowExecutionContext.DataflowExecutionStateTracker; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.KeyTransitionListener; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.StreamingModeExecutionStateRegistry; import org.apache.beam.runners.dataflow.worker.WorkerCustomSources.SplittableOnlyBoundedSource; import org.apache.beam.runners.dataflow.worker.counters.CounterSet; @@ -93,7 +94,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.config.FixedGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcherFactory; import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader.NativeReaderIterator; @@ -101,7 +102,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; -import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; @@ -209,6 +209,16 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla Instant::now); } + private void startContext(StreamingModeExecutionContext context, Work work) { + context.start( + work, + mock(WorkExecutor.class), + /* workQueueExecutor= */ null, + /* budgetHandle= */ null, + /* keyCoder= */ null, + /* keyTransitionListener= */ mock(KeyTransitionListener.class)); + } + private static class SourceProducingSubSourcesInSplit extends MockSource { int numDesiredBundle; int sourceObjectSize; @@ -620,7 +630,13 @@ public void testReadUnboundedReader() throws Exception { executionStateRegistry, globalConfigHandle, Long.MAX_VALUE, - /*throwExceptionOnLargeOutput=*/ false); + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName", + options, + SideInputStateFetcherFactory.fromOptions(options)); options.setNumWorkers(5); int maxElements = 10; @@ -631,8 +647,8 @@ public void testReadUnboundedReader() throws Exception { for (int i = 0; i < 10 * maxElements; /* Incremented in inner loop */ ) { // Initialize streaming context with state from previous iteration. - context.start( - "key", + startContext( + context, createMockWork( Windmill.WorkItem.newBuilder() .setKey(ByteString.copyFromUtf8("0000000000000001")) // key is zero-padded index. @@ -641,11 +657,7 @@ public void testReadUnboundedReader() throws Exception { .setSourceState( Windmill.SourceState.newBuilder().setState(state).build()) // Source state. .build(), - Watermarks.builder().setInputDataWatermark(new Instant(0)).build()), - mock(WindmillStateReader.class), - mock(SideInputStateFetcher.class), - Windmill.WorkItemCommitRequest.newBuilder(), - mock(WorkExecutor.class)); + Watermarks.builder().setInputDataWatermark(new Instant(0)).build())); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader>>> reader = @@ -992,7 +1004,13 @@ public void testFailedWorkItemsAbort() throws Exception { executionStateRegistry, globalConfigHandle, Long.MAX_VALUE, - /*throwExceptionOnLargeOutput=*/ false); + /*throwExceptionOnLargeOutput=*/ false, + new HotKeyLogger(), + /*hotKeyLoggingEnabled=*/ false, + /*stepName=*/ "stepName", + "sourceBytesProcessCounterName", + options, + SideInputStateFetcherFactory.fromOptions(options)); options.setNumWorkers(5); int maxElements = 100; @@ -1020,13 +1038,7 @@ public void testFailedWorkItemsAbort() throws Exception { mock(HeartbeatSender.class)), false, Instant::now); - context.start( - "key", - dummyWork, - mock(WindmillStateReader.class), - mock(SideInputStateFetcher.class), - Windmill.WorkItemCommitRequest.newBuilder(), - mock(WorkExecutor.class)); + startContext(context, dummyWork); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader>>> reader = diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java index 0f14efdd0c0b..60d7bb71a9de 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java @@ -565,6 +565,31 @@ public void testFailWork_batchFail() { } } + @Test + public void testGetActiveWork() { + ShardedKey shardedKey = shardedKey("someKey", 1L); + ExecutableWork work = createWork(createWorkItem(1L, 1L, shardedKey)); + + // Initially empty + assertFalse(activeWorkState.getActiveWork(shardedKey, work.id()).isPresent()); + + // Activate work + activeWorkState.activateWorkForKey(work); + + // Should find it now + Optional activeWork = activeWorkState.getActiveWork(shardedKey, work.id()); + assertTrue(activeWork.isPresent()); + assertSame(work, activeWork.get()); + + // Should not find it with different workId + assertFalse(activeWorkState.getActiveWork(shardedKey, workId(2L, 1L)).isPresent()); + assertFalse(activeWorkState.getActiveWork(shardedKey, workId(1L, 2L)).isPresent()); + + // Should not find it with different shardedKey + ShardedKey otherShardedKey = shardedKey("otherKey", 2L); + assertFalse(activeWorkState.getActiveWork(otherShardedKey, work.id()).isPresent()); + } + private static ExecutableWork firstValue(Map map) { Iterator> iterator = map.entrySet().iterator(); if (iterator.hasNext()) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateTest.java new file mode 100644 index 000000000000..6a6edd2b7192 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateTest.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import com.google.api.services.dataflow.model.MapTask; +import java.util.Collections; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ComputationStateTest { + + private final BoundedQueueExecutor mockExecutor = mock(BoundedQueueExecutor.class); + private final WindmillStateCache.ForComputation mockStateCache = + mock(WindmillStateCache.ForComputation.class); + private final HeartbeatSender mockHeartbeatSender = mock(HeartbeatSender.class); + + private ComputationState computationState; + + private static ShardedKey shardedKey(String str, long shardKey) { + return ShardedKey.create(ByteString.copyFromUtf8(str), shardKey); + } + + private ExecutableWork createWork(Windmill.WorkItem workItem) { + return ExecutableWork.create( + Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), + Work.createProcessingContext( + "computationId", new FakeGetDataClient(), ignored -> {}, mockHeartbeatSender), + false, + Instant::now), + (work, handle) -> {}); + } + + private static Windmill.WorkItem createWorkItem( + long workToken, long cacheToken, ShardedKey shardedKey) { + return Windmill.WorkItem.newBuilder() + .setShardingKey(shardedKey.shardingKey()) + .setKey(shardedKey.key()) + .setWorkToken(workToken) + .setCacheToken(cacheToken) + .build(); + } + + @Before + public void setUp() { + MapTask mapTask = new MapTask(); + mapTask.setStageName("stage"); + mapTask.setSystemName("system"); + computationState = + new ComputationState( + "computationId", mapTask, mockExecutor, Collections.emptyMap(), mockStateCache); + } + + @Test + public void testReExecuteActiveWork_workNotActive() { + ShardedKey shardedKey = shardedKey("key", 1L); + WorkId workId = WorkId.builder().setWorkToken(1L).setCacheToken(1L).build(); + + computationState.reExecuteActiveWork(shardedKey, workId); + + verifyNoInteractions(mockExecutor); + } + + @Test + public void testReExecuteActiveWork_workActive() { + ShardedKey shardedKey = shardedKey("key", 1L); + Windmill.WorkItem workItem = createWorkItem(1L, 1L, shardedKey); + ExecutableWork work = createWork(workItem); + + // Activate work first. This will execute it once. + computationState.activateWork(work); + verify(mockExecutor).execute(work, work.work().getSerializedWorkItemSize()); + + // Now re-execute + computationState.reExecuteActiveWork(shardedKey, work.id()); + verify(mockExecutor).forceExecute(work, work.work().getSerializedWorkItemSize()); + + verifyNoMoreInteractions(mockExecutor); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java new file mode 100644 index 000000000000..80ca91da462f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WorkTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class WorkTest { + + private static Work createTestWork() { + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("key")) + .setWorkToken(1L) + .setShardingKey(2L) + .build(); + return Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(Instant.now()).build(), + Work.createProcessingContext( + "comp", + mock( + org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient + .class), + commit -> {}, + mock(HeartbeatSender.class)), + false, + Instant::now); + } + + @Test + public void testSetFailedBeforeListener() { + Work work = createTestWork(); + assertFalse(work.isFailed()); + + work.setFailed(); + assertTrue(work.isFailed()); + + AtomicBoolean listener = new AtomicBoolean(false); + work.setOnFailureListener(listener); + assertTrue(listener.get()); + } + + @Test + public void testSetFailedAfterListener() { + Work work = createTestWork(); + AtomicBoolean listener = new AtomicBoolean(false); + work.setOnFailureListener(listener); + assertFalse(listener.get()); + assertFalse(work.isFailed()); + + work.setFailed(); + assertTrue(work.isFailed()); + assertTrue(listener.get()); + } + + @Test + public void testConcurrentSetFailedAndSetOnFailureListener() throws Exception { + int numTrials = 5000; + ExecutorService executor = Executors.newFixedThreadPool(2); + try { + for (int i = 0; i < numTrials; i++) { + Work work = createTestWork(); + AtomicBoolean listener = new AtomicBoolean(false); + CountDownLatch latch = new CountDownLatch(1); + + Future f1 = + executor.submit( + () -> { + try { + latch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + work.setFailed(); + }); + + Future f2 = + executor.submit( + () -> { + try { + latch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + work.setOnFailureListener(listener); + }); + + latch.countDown(); + f1.get(5, TimeUnit.SECONDS); + f2.get(5, TimeUnit.SECONDS); + + assertTrue("Trial " + i + " failed: work should be failed", work.isFailed()); + assertTrue("Trial " + i + " failed: listener should be set to true", listener.get()); + } + } finally { + executor.shutdownNow(); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java index a98102751fb2..245d600448fe 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java @@ -30,7 +30,10 @@ import java.util.Collection; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import java.util.function.Consumer; +import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; @@ -82,6 +85,14 @@ private static ExecutableWork createWorkWithCompId( private static ExecutableWork createWorkWithCompIdAndKeyGroup( String computationId, Work.KeyGroup keyGroup, Consumer executeWorkFn) { + return createWorkWithHandle( + computationId, keyGroup, (work, handle) -> executeWorkFn.accept(work)); + } + + private static ExecutableWork createWorkWithHandle( + String computationId, + Work.KeyGroup keyGroup, + BiConsumer executeWorkFn) { WorkItem workItem = WorkItem.newBuilder() .setKey(ByteString.EMPTY) @@ -103,9 +114,7 @@ private static ExecutableWork createWorkWithCompIdAndKeyGroup( computationId, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), false, Instant::now), - (work, handle) -> { - executeWorkFn.accept(work); - }); + executeWorkFn); } private ExecutableWork createSleepProcessWork(CountDownLatch start, CountDownLatch stop) { @@ -406,18 +415,25 @@ public void testRunnableExceptionPropagationDecrementsCounters() throws Exceptio @Test public void testHandleMerge() throws Exception { - BoundedQueueExecutorWorkHandleImpl handle1 = executor.createBudgetHandle(1, 100L); - BoundedQueueExecutorWorkHandleImpl handle2 = executor.createBudgetHandle(2, 200L); + Work work1 = createWork(ignored -> {}).work(); + Work work2 = createWork(ignored -> {}).work(); + Work work3 = createWork(ignored -> {}).work(); + BoundedQueueExecutorWorkHandleImpl handle1 = executor.createBudgetHandle(work1, 100L); + BoundedQueueExecutorWorkHandleImpl handle2 = executor.createBudgetHandle(work2, 200L); + handle2.merge(executor.createBudgetHandle(work3, 0L)); handle1.merge(handle2); // Verify that handle2 has 0 budget and is closed. - assertEquals(0, handle2.elements()); + assertEquals(0, handle2.getWorkBatch().size()); assertEquals(0, handle2.bytes()); assertTrue(handle2.isClosed()); // Verify that handle1 has the combined budget and is not closed. - assertEquals(3, handle1.elements()); + assertEquals(3, handle1.getWorkBatch().size()); + assertTrue(handle1.getWorkBatch().contains(work1)); + assertTrue(handle1.getWorkBatch().contains(work2)); + assertTrue(handle1.getWorkBatch().contains(work3)); assertEquals(300L, handle1.bytes()); assertFalse(handle1.isClosed()); } @@ -449,11 +465,13 @@ public void testPollWork() throws Exception { // 1. Create blocker task to occupy the worker thread CountDownLatch blockerStart = new CountDownLatch(1); CountDownLatch blockerStop = new CountDownLatch(1); + AtomicReference blockerHandleRef = new AtomicReference<>(); ExecutableWork blockerWork = - createWorkWithCompIdAndKeyGroup( + createWorkWithHandle( "blockerComp", DEFAULT_KEY_GROUP, - ignored -> { + (work, handle) -> { + blockerHandleRef.set(handle); blockerStart.countDown(); try { blockerStop.await(); @@ -464,6 +482,9 @@ public void testPollWork() throws Exception { testExecutor.execute(blockerWork, 0); blockerStart.await(); + BoundedQueueExecutorWorkHandleImpl stealHandle = + (BoundedQueueExecutorWorkHandleImpl) blockerHandleRef.get(); + assertNotNull(stealHandle); // 2. Create two distinct key groups Work.KeyGroup keyGroup1 = Work.KeyGroup.create(1, 1); @@ -488,22 +509,18 @@ public void testPollWork() throws Exception { assertEquals(3, testExecutor.elementsOutstanding()); // Steal work2 using pollWork with compA and keyGroup2 - try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { - ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup2, stealHandle); - assertNotNull(stolen); - assertEquals(work2, stolen); - - // Run the stolen task - stolen.run(stealHandle); - targetStart.await(); - } + ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup2, stealHandle); + assertNotNull(stolen); + assertEquals(work2, stolen); + + // Run the stolen task + stolen.run(stealHandle); + targetStart.await(); // Steal work1 using pollWork with compA and keyGroup1 - try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { - ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup1, stealHandle); - assertNotNull(stolen); - assertEquals(work1, stolen); - } + ExecutableWork stolen1 = testExecutor.pollWork("compA", keyGroup1, stealHandle); + assertNotNull(stolen1); + assertEquals(work1, stolen1); // Unblock the blocker and shut down blockerStop.countDown(); @@ -525,11 +542,13 @@ public void testPollWorkWithLinkedBlockingQueue() throws Exception { CountDownLatch blockerStart = new CountDownLatch(1); CountDownLatch blockerStop = new CountDownLatch(1); + AtomicReference blockerHandleRef = new AtomicReference<>(); ExecutableWork blockerWork = - createWorkWithCompIdAndKeyGroup( + createWorkWithHandle( "blockerComp", DEFAULT_KEY_GROUP, - ignored -> { + (work, handle) -> { + blockerHandleRef.set(handle); blockerStart.countDown(); try { blockerStop.await(); @@ -540,15 +559,16 @@ public void testPollWorkWithLinkedBlockingQueue() throws Exception { testExecutor.execute(blockerWork, 0); blockerStart.await(); + BoundedQueueExecutorWorkHandleImpl stealHandle = + (BoundedQueueExecutorWorkHandleImpl) blockerHandleRef.get(); + assertNotNull(stealHandle); Work.KeyGroup keyGroup = Work.KeyGroup.create(1, 1); ExecutableWork work = createWorkWithCompIdAndKeyGroup("compA", keyGroup, ignored -> {}); testExecutor.execute(work, 100); - try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { - ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup, stealHandle); - assertNull(stolen); - } + ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup, stealHandle); + assertNull(stolen); blockerStop.countDown(); testExecutor.shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java index 994aa2030f3f..307cbde36989 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java @@ -63,7 +63,6 @@ public static Iterable data() { } @Parameterized.Parameter public boolean fairQueue; - private BoundedQueueExecutor executor; @Before @@ -116,7 +115,7 @@ private QueuedWork createQueuedWork( false, Instant::now), (w, h) -> {}); - return new QueuedWork(work, executor.createBudgetHandle(1, workBytes)); + return new QueuedWork(work, executor.createBudgetHandle(work.work(), workBytes)); } private static class NoOpRunnable implements Runnable { @@ -312,7 +311,6 @@ public String toString() { } })); } - // Start producers for (int i = 0; i < producerThreads; i++) { futures.add( @@ -470,7 +468,6 @@ public void testPollWorkWithKeyGroup() { QueuedWork polledNotExist = queue.pollWork("compA", keyGroupNotExist); assertNull(polledNotExist); assertEquals(2, queue.size()); - // Poll with keyGroup2 first - should return workA2 QueuedWork polledA2 = queue.pollWork("compA", keyGroup2); assertNotNull(polledA2); @@ -485,7 +482,6 @@ public void testPollWorkWithKeyGroup() { assertNotNull(polledA1); assertEquals(workA1, polledA1); assertTrue(queue.isEmpty()); - polledNotExist = queue.pollWork("compA", keyGroupNotExist); assertNull(polledNotExist); assertTrue(queue.isEmpty()); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java index 5c3132ae471d..3da740d53361 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java @@ -129,9 +129,9 @@ public void testCommit() { for (Commit commit : commits) { Windmill.WorkItemCommitRequest request = - committed.get(commit.work().getWorkItem().getWorkToken()); + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); } assertThat(completeCommits).hasSize(commits.size()); @@ -141,12 +141,13 @@ public void testCommit() { (CompleteCommit completeCommit, Commit commit) -> completeCommit.computationId().equals(commit.computationId()) && completeCommit.status() == Windmill.CommitStatus.OK - && completeCommit.workId().equals(commit.work().id()) + && completeCommit.workId().equals(commit.workBatch().get(0).id()) && completeCommit .shardedKey() .equals( ShardedKey.create( - commit.request().getKey(), commit.request().getShardingKey())), + commit.singleKeyRequest().get().getKey(), + commit.singleKeyRequest().get().getShardingKey())), "expected to equal")) .containsExactlyElementsIn(commits); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 01197622c24d..a48159338132 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -53,6 +53,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.WorkId; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; @@ -62,6 +63,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Duration; import org.joda.time.Instant; @@ -134,12 +136,11 @@ private static ComputationState createComputationState(String computationId) { null); } - private static CompleteCommit asCompleteCommit(Commit commit, Windmill.CommitStatus status) { - if (commit.work().isFailed()) { - return CompleteCommit.forFailedWork(commit); - } - - return CompleteCommit.create(commit, status); + private static CompleteCommit asCompleteCommit( + String computationId, Work work, Windmill.CommitStatus status) { + Windmill.CommitStatus finalStatus = work.isFailed() ? Windmill.CommitStatus.ABORTED : status; + return CompleteCommit.create( + computationId, work.getShardedKey(), work.id(), finalStatus, /* retryableFailure= */ false); } @Before @@ -186,10 +187,14 @@ public void testCommit_sendsCommitsToStreamingEngine() { waitForExpectedSetSize(completeCommits, 5); for (Commit commit : commits) { - WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + WorkItemCommitRequest request = + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); - assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); + assertThat(completeCommits) + .contains( + asCompleteCommit( + commit.computationId(), commit.workBatch().get(0), Windmill.CommitStatus.OK)); } workCommitter.stop(); @@ -224,14 +229,24 @@ public void testCommit_handlesFailedCommits() { waitForExpectedSetSize(completeCommits, 10); for (Commit commit : commits) { - if (commit.work().isFailed()) { + if (commit.isFailed()) { assertThat(completeCommits) - .contains(asCompleteCommit(commit, Windmill.CommitStatus.ABORTED)); - assertThat(committed).doesNotContainKey(commit.work().getWorkItem().getWorkToken()); + .contains( + asCompleteCommit( + commit.computationId(), + commit.workBatch().get(0), + Windmill.CommitStatus.ABORTED)); + assertThat(committed) + .doesNotContainKey(commit.workBatch().get(0).getWorkItem().getWorkToken()); } else { - assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(completeCommits) + .contains( + asCompleteCommit( + commit.computationId(), commit.workBatch().get(0), Windmill.CommitStatus.OK)); assertThat(committed) - .containsEntry(commit.work().getWorkItem().getWorkToken(), commit.request()); + .containsEntry( + commit.workBatch().get(0).getWorkItem().getWorkToken(), + commit.singleKeyRequest().get()); } } @@ -282,11 +297,16 @@ public void testCommit_handlesCompleteCommits_commitStatusNotOK() { waitForExpectedSetSize(completeCommits, commits.size()); for (Commit commit : commits) { - WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + WorkItemCommitRequest request = + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); assertThat(completeCommits) - .contains(asCompleteCommit(commit, expectedCommitStatus.get(commit.work().id()))); + .contains( + asCompleteCommit( + commit.computationId(), + commit.workBatch().get(0), + expectedCommitStatus.get(commit.workBatch().get(0).id()))); } workCommitter.stop(); @@ -313,6 +333,14 @@ public boolean commitWorkItem( return false; } + @Override + public boolean commitMultiKeyWorkItem( + String computation, + Windmill.MultiKeyWorkItemCommitRequest request, + Consumer onDone) { + return false; + } + @Override public void flush() {} }; @@ -367,10 +395,11 @@ public void shutdown() {} assertThat(commits.size()).isEqualTo(completeCommits.size()); for (CompleteCommit completeCommit : completeCommits) { assertThat(completeCommit.status()).isEqualTo(Windmill.CommitStatus.ABORTED); + assertThat(completeCommit.retryableFailure()).isFalse(); } for (Commit commit : commits) { - assertTrue(commit.work().isFailed()); + assertTrue(commit.isFailed()); } } @@ -409,10 +438,14 @@ public void testMultipleCommitSendersSingleStream() { waitForExpectedSetSize(completeCommits, commits.size()); for (Commit commit : commits) { - WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + WorkItemCommitRequest request = + committed.get(commit.workBatch().get(0).getWorkItem().getWorkToken()); assertNotNull(request); - assertThat(request).isEqualTo(commit.request()); - assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(request).isEqualTo(commit.singleKeyRequest().get()); + assertThat(completeCommits) + .contains( + asCompleteCommit( + commit.computationId(), commit.workBatch().get(0), Windmill.CommitStatus.OK)); } workCommitter.stop(); @@ -474,4 +507,242 @@ public void testStop_drainsCommitQueue_concurrentCommit() waitForExpectedSetSize(completeCommits, sentCommits.intValue()); } + + @Test + public void testCommit_multiKeyCommitFailedWork() { + Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = createWorkCommitter(completeCommits::add); + + Work workA = createMockWork(101L); + Work workB = createMockWork(102L); + Work workC = createMockWork(103L); + + // Mark non-primary key B as failed + workB.setFailed(); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workA.getWorkItem().getKey()) + .setShardingKey(workA.getWorkItem().getShardingKey()) + .setWorkToken(workA.getWorkItem().getWorkToken()) + .setCacheToken(workA.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workB.getWorkItem().getKey()) + .setShardingKey(workB.getWorkItem().getShardingKey()) + .setWorkToken(workB.getWorkItem().getWorkToken()) + .setCacheToken(workB.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workC.getWorkItem().getKey()) + .setShardingKey(workC.getWorkItem().getShardingKey()) + .setWorkToken(workC.getWorkItem().getWorkToken()) + .setCacheToken(workC.getWorkItem().getCacheToken()) + .build()) + .build(); + + Commit commit = + Commit.createMultiKey( + multiKeyRequest, + createComputationState("computationId"), + ImmutableList.of(workA, workB, workC)); + + workCommitter.start(); + workCommitter.commit(commit); + + // The entire batch must be aborted immediately without making network calls + waitForExpectedSetSize(completeCommits, 3); + + // Verify all three works are aborted individually + assertThat(completeCommits) + .containsExactly( + CompleteCommit.create( + "computationId", + workA.getShardedKey(), + workA.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ true), + CompleteCommit.create( + "computationId", + workB.getShardedKey(), + workB.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ false), + CompleteCommit.create( + "computationId", + workC.getShardedKey(), + workC.id(), + CommitStatus.ABORTED, + /* retryableFailure= */ true)); + + // Verify that valid work was not marked failed + assertThat(workA.isFailed()).isFalse(); + assertThat(workC.isFailed()).isFalse(); + assertThat(workB.isFailed()).isTrue(); + + workCommitter.stop(); + } + + @Test + public void testCommit_multiKeyCommitSuccess() { + Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = createWorkCommitter(completeCommits::add); + + Work workA = createMockWork(101L); + Work workB = createMockWork(102L); + Work workC = createMockWork(103L); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workA.getWorkItem().getKey()) + .setShardingKey(workA.getWorkItem().getShardingKey()) + .setWorkToken(workA.getWorkItem().getWorkToken()) + .setCacheToken(workA.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workB.getWorkItem().getKey()) + .setShardingKey(workB.getWorkItem().getShardingKey()) + .setWorkToken(workB.getWorkItem().getWorkToken()) + .setCacheToken(workB.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workC.getWorkItem().getKey()) + .setShardingKey(workC.getWorkItem().getShardingKey()) + .setWorkToken(workC.getWorkItem().getWorkToken()) + .setCacheToken(workC.getWorkItem().getCacheToken()) + .build()) + .build(); + + Commit commit = + Commit.createMultiKey( + multiKeyRequest, + createComputationState("computationId"), + ImmutableList.of(workA, workB, workC)); + + workCommitter.start(); + workCommitter.commit(commit); + + // Wait for the server to receive and process the commits + fakeWindmillServer.waitForAndGetCommits(3); + waitForExpectedSetSize(completeCommits, 3); + + // Verify that FakeWindmillServer received all 3 work requests in multiKeyCommitsReceived + List multiKeyCommits = + fakeWindmillServer.getMultiKeyCommitsReceived(); + assertThat(multiKeyCommits).hasSize(1); + assertThat(multiKeyCommits.get(0)).isEqualTo(multiKeyRequest); + + // Verify all three works are completed successfully + assertThat(completeCommits) + .containsExactly( + CompleteCommit.create( + "computationId", + workA.getShardedKey(), + workA.id(), + CommitStatus.OK, + /* retryableFailure= */ false), + CompleteCommit.create( + "computationId", + workB.getShardedKey(), + workB.id(), + CommitStatus.OK, + /* retryableFailure= */ false), + CompleteCommit.create( + "computationId", + workC.getShardedKey(), + workC.id(), + CommitStatus.OK, + /* retryableFailure= */ false)); + + workCommitter.stop(); + } + + @Test + public void testCommit_multiKeyCommitStatusNotOK() { + Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = createWorkCommitter(completeCommits::add); + + Work workA = createMockWork(101L); + Work workB = createMockWork(102L); + Work workC = createMockWork(103L); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workA.getWorkItem().getKey()) + .setShardingKey(workA.getWorkItem().getShardingKey()) + .setWorkToken(workA.getWorkItem().getWorkToken()) + .setCacheToken(workA.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workB.getWorkItem().getKey()) + .setShardingKey(workB.getWorkItem().getShardingKey()) + .setWorkToken(workB.getWorkItem().getWorkToken()) + .setCacheToken(workB.getWorkItem().getCacheToken()) + .build()) + .addRequests( + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(workC.getWorkItem().getKey()) + .setShardingKey(workC.getWorkItem().getShardingKey()) + .setWorkToken(workC.getWorkItem().getWorkToken()) + .setCacheToken(workC.getWorkItem().getCacheToken()) + .build()) + .build(); + + Commit commit = + Commit.createMultiKey( + multiKeyRequest, + createComputationState("computationId"), + ImmutableList.of(workA, workB, workC)); + + // Offer NOT_FOUND status for one of the works. + fakeWindmillServer.whenCommitWorkStreamCalled().put(workB.id(), CommitStatus.NOT_FOUND); + + workCommitter.start(); + workCommitter.commit(commit); + + // Wait for the server to receive and process the commits + fakeWindmillServer.waitForAndGetCommits(3); + waitForExpectedSetSize(completeCommits, 3); + + // Verify that FakeWindmillServer received the multi-key commit + List multiKeyCommits = + fakeWindmillServer.getMultiKeyCommitsReceived(); + assertThat(multiKeyCommits).hasSize(1); + assertThat(multiKeyCommits.get(0)).isEqualTo(multiKeyRequest); + + // Verify all three works in the multi-key commit are completed with NOT_FOUND status + assertThat(completeCommits) + .containsExactly( + CompleteCommit.create( + "computationId", + workA.getShardedKey(), + workA.id(), + CommitStatus.NOT_FOUND, + /* retryableFailure= */ false), + CompleteCommit.create( + "computationId", + workB.getShardedKey(), + workB.id(), + CommitStatus.NOT_FOUND, + /* retryableFailure= */ false), + CompleteCommit.create( + "computationId", + workC.getShardedKey(), + workC.id(), + CommitStatus.NOT_FOUND, + /* retryableFailure= */ false)); + + workCommitter.stop(); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index e9fd55fa5668..fc8348a68ce6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -1133,6 +1133,264 @@ public void testCommitWorkItem_multiplePhysicalStreams_multipleHandovers_halfClo assertTrue(commitWorkStream.awaitTermination(10, TimeUnit.SECONDS)); } + @Test + public void testCommit_multiKeyCommit() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + + // 1. Construct two individual WorkItemCommitRequests + long shardingKey1 = 101L; + long workToken1 = 201L; + long cacheToken1 = 301L; + long shardingKey2 = 102L; + long workToken2 = 202L; + long cacheToken2 = 302L; + Windmill.WorkItemCommitRequest request1 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setShardingKey(shardingKey1) + .setWorkToken(workToken1) + .setCacheToken(cacheToken1) + .build(); + Windmill.WorkItemCommitRequest request2 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setShardingKey(shardingKey2) + .setWorkToken(workToken2) + .setCacheToken(cacheToken2) + .build(); + + // 2. Wrap them into a MultiKeyWorkItemCommitRequest + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests(request1) + .addRequests(request2) + .build(); + + // 3. Commit the multi-key work item using the request batcher + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitMultiKeyWorkItem( + COMPUTATION_ID, multiKeyRequest, commitStatusFuture::complete)); + } + + // 4. Receive and assert request properties on FakeWindmillGrpcService + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertThat(request.getCommitChunkCount()).isEqualTo(1); + + Windmill.StreamingCommitRequestChunk chunk = request.getCommitChunk(0); + + // Assert that the commit type is correctly identified as COMMIT_TYPE_MULTI_KEY + assertThat(chunk.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + + // Assert that the routing sharding key is mapped to the first request's sharding key + assertThat(chunk.getShardingKey()).isEqualTo(request1.getShardingKey()); + + // Assert that the serialized payload matches the input multiKeyRequest + Windmill.MultiKeyWorkItemCommitRequest parsedRequest = + Windmill.MultiKeyWorkItemCommitRequest.parseFrom(chunk.getSerializedWorkItemCommit()); + assertThat(parsedRequest).isEqualTo(multiKeyRequest); + + // 5. Respond with the generated requestId to complete the commit + long requestId = chunk.getRequestId(); + streamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + + // 6. Verify callback completed successfully with CommitStatus.OK + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + } + + @Test + public void testCommit_multiKeyCommit_multichunk() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + + long shardingKey1 = 101L; + long workToken1 = 201L; + long cacheToken1 = 301L; + long shardingKey2 = 102L; + long workToken2 = 202L; + long cacheToken2 = 302L; + + Windmill.WorkItemCommitRequest request1 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setShardingKey(shardingKey1) + .setWorkToken(workToken1) + .setCacheToken(cacheToken1) + .addBagUpdates(Windmill.TagBag.newBuilder().setTag(LARGE_BYTE_STRING).build()) + .build(); + + Windmill.WorkItemCommitRequest request2 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key2")) + .setShardingKey(shardingKey2) + .setWorkToken(workToken2) + .setCacheToken(cacheToken2) + .build(); + + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder() + .addRequests(request1) + .addRequests(request2) + .build(); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitMultiKeyWorkItem( + COMPUTATION_ID, multiKeyRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest requestChunk1 = streamInfo.requests.take(); + assertThat(requestChunk1.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk1 = requestChunk1.getCommitChunk(0); + + assertThat(chunk1.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + assertThat(chunk1.getShardingKey()).isEqualTo(request1.getShardingKey()); + assertThat(chunk1.getRemainingBytesForWorkItem()).isGreaterThan(0); + + Windmill.StreamingCommitWorkRequest requestChunk2 = streamInfo.requests.take(); + assertThat(requestChunk2.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk2 = requestChunk2.getCommitChunk(0); + + assertThat(chunk2.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + assertThat(chunk2.getShardingKey()).isEqualTo(request1.getShardingKey()); + assertThat(chunk2.getRemainingBytesForWorkItem()).isEqualTo(0); + + ByteString reconstructedBytes = + chunk1.getSerializedWorkItemCommit().concat(chunk2.getSerializedWorkItemCommit()); + Windmill.MultiKeyWorkItemCommitRequest parsedRequest = + Windmill.MultiKeyWorkItemCommitRequest.parseFrom(reconstructedBytes); + assertThat(parsedRequest).isEqualTo(multiKeyRequest); + + long requestId = chunk1.getRequestId(); + assertThat(chunk2.getRequestId()).isEqualTo(requestId); + + streamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + } + + @Test + public void testCommitMultiKeyWorkItem_retryOnNewStream() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + + long shardingKey1 = 101L; + long workToken1 = 201L; + long cacheToken1 = 301L; + Windmill.WorkItemCommitRequest request1 = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("key1")) + .setShardingKey(shardingKey1) + .setWorkToken(workToken1) + .setCacheToken(cacheToken1) + .build(); + Windmill.MultiKeyWorkItemCommitRequest multiKeyRequest = + Windmill.MultiKeyWorkItemCommitRequest.newBuilder().addRequests(request1).build(); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitMultiKeyWorkItem( + COMPUTATION_ID, multiKeyRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertThat(request.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk = request.getCommitChunk(0); + assertThat(chunk.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + long requestId = chunk.getRequestId(); + + streamInfo.responseObserver.onError(new IOException("test error")); + + FakeWindmillGrpcService.CommitStreamInfo reconnectStreamInfo = + waitForConnectionAndConsumeHeader(); + Windmill.StreamingCommitWorkRequest reconnectRequest = reconnectStreamInfo.requests.take(); + assertThat(reconnectRequest.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk reconnectChunk = reconnectRequest.getCommitChunk(0); + assertThat(reconnectChunk.getCommitType()) + .isEqualTo(Windmill.StreamingCommitRequestChunk.CommitType.COMMIT_TYPE_MULTI_KEY); + assertThat(reconnectChunk.getRequestId()).isEqualTo(requestId); + + Windmill.MultiKeyWorkItemCommitRequest parsedRequest = + Windmill.MultiKeyWorkItemCommitRequest.parseFrom( + reconnectChunk.getSerializedWorkItemCommit()); + assertThat(parsedRequest).isEqualTo(multiKeyRequest); + + reconnectStreamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + } + + @Test + public void testCommitWorkItem_retryOnNewStream_multichunk() throws Exception { + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + + Windmill.WorkItemCommitRequest largeRequest = + workItemCommitRequest(1) + .toBuilder() + .addBagUpdates(Windmill.TagBag.newBuilder().setTag(LARGE_BYTE_STRING).build()) + .build(); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem(COMPUTATION_ID, largeRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest requestChunk1 = streamInfo.requests.take(); + assertThat(requestChunk1.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk1 = requestChunk1.getCommitChunk(0); + long requestId = chunk1.getRequestId(); + assertThat(chunk1.getRemainingBytesForWorkItem()).isGreaterThan(0); + + Windmill.StreamingCommitWorkRequest requestChunk2 = streamInfo.requests.take(); + assertThat(requestChunk2.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk chunk2 = requestChunk2.getCommitChunk(0); + assertThat(chunk2.getRequestId()).isEqualTo(requestId); + assertThat(chunk2.getRemainingBytesForWorkItem()).isEqualTo(0); + + streamInfo.responseObserver.onError(new IOException("test error")); + + FakeWindmillGrpcService.CommitStreamInfo reconnectStreamInfo = + waitForConnectionAndConsumeHeader(); + + Windmill.StreamingCommitWorkRequest reconnectChunk1 = reconnectStreamInfo.requests.take(); + assertThat(reconnectChunk1.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk reconChunk1 = reconnectChunk1.getCommitChunk(0); + assertThat(reconChunk1.getRequestId()).isEqualTo(requestId); + assertThat(reconChunk1.getRemainingBytesForWorkItem()).isGreaterThan(0); + + Windmill.StreamingCommitWorkRequest reconnectChunk2 = reconnectStreamInfo.requests.take(); + assertThat(reconnectChunk2.getCommitChunkCount()).isEqualTo(1); + Windmill.StreamingCommitRequestChunk reconChunk2 = reconnectChunk2.getCommitChunk(0); + assertThat(reconChunk2.getRequestId()).isEqualTo(requestId); + assertThat(reconChunk2.getRemainingBytesForWorkItem()).isEqualTo(0); + + ByteString reconstructedBytes = + reconChunk1.getSerializedWorkItemCommit().concat(reconChunk2.getSerializedWorkItemCommit()); + Windmill.WorkItemCommitRequest parsedRequest = + Windmill.WorkItemCommitRequest.parseFrom(reconstructedBytes); + assertThat(parsedRequest).isEqualTo(largeRequest); + + reconnectStreamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(requestId).build()); + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + } + private FakeWindmillGrpcService.CommitStreamInfo waitForConnectionAndConsumeHeader() { try { FakeWindmillGrpcService.CommitStreamInfo info = fakeService.waitForConnectedCommitStream(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java index 1611fdac25dc..65637437a0a0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java @@ -35,9 +35,9 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.Future; -import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; import org.apache.beam.runners.dataflow.worker.WindmillStateTestUtils; import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; +import org.apache.beam.runners.dataflow.worker.WorkCancelingException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListEntry; @@ -1572,16 +1572,16 @@ public void testKeyTokenInvalid() throws Exception { try { watermarkFuture.get(); - fail("Expected KeyTokenInvalidException"); + fail("Expected WorkCancelingException"); } catch (Exception e) { - assertTrue(KeyTokenInvalidException.isKeyTokenInvalidException(e)); + assertTrue(WorkCancelingException.isWorkCancelingException(e)); } try { bagFuture.get(); - fail("Expected KeyTokenInvalidException"); + fail("Expected WorkCancelingException"); } catch (Exception e) { - assertTrue(KeyTokenInvalidException.isKeyTokenInvalidException(e)); + assertTrue(WorkCancelingException.isWorkCancelingException(e)); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java index 0610ed44c27f..ce9fe53f47d3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java @@ -21,15 +21,15 @@ import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; +import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Optional; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Supplier; -import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; -import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; @@ -109,38 +109,22 @@ private static ExecutableWork createWork(Consumer processWorkFn) { } @Test - public void logAndProcessFailure_doesNotRetryKeyTokenInvalidException() throws Throwable { + public void logAndProcessFailureBatch_doesNotRetryFailedWork() throws Throwable { Set executedWork = new HashSet<>(); ExecutableWork work = createWork(executedWork::add); + work.work().setFailed(); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new KeyTokenInvalidException("key"), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, List.of(work), new RuntimeException(), invalidWork::add); assertThat(executedWork).isEmpty(); assertThat(invalidWork).containsExactly(work.work()); } @Test - public void logAndProcessFailure_doesNotRetryWhenWorkItemCancelled() throws Throwable { - Set executedWork = new HashSet<>(); - ExecutableWork work = createWork(executedWork::add); - WorkFailureProcessor workFailureProcessor = - createWorkFailureProcessor(streamingEngineFailureReporter()); - Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, - work, - new WorkItemCancelledException(work.getWorkItem().getShardingKey()), - invalidWork::add); - - assertThat(executedWork).isEmpty(); - assertThat(invalidWork).containsExactly(work.work()); - } - - @Test - public void logAndProcessFailure_doesNotRetryOOM() { + public void logAndProcessFailureBatch_doesNotRetryOOM() { Set executedWork = new HashSet<>(); ExecutableWork work = createWork(executedWork::add); WorkFailureProcessor workFailureProcessor = @@ -149,69 +133,120 @@ public void logAndProcessFailure_doesNotRetryOOM() { assertThrows( OutOfMemoryError.class, () -> - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new OutOfMemoryError(), invalidWork::add)); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, + Arrays.asList(work), + new OutOfMemoryError(), + invalidWork::add)); assertThat(executedWork).isEmpty(); } @Test - public void logAndProcessFailure_doesNotRetryWhenFailureReporterMarksAsNonRetryable() + public void logAndProcessFailureBatch_doesNotRetryWhenFailureReporterMarksAsNonRetryable() throws Throwable { Set executedWork = new HashSet<>(); ExecutableWork work = createWork(executedWork::add); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingApplianceFailureReporter(true)); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new RuntimeException(), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, Arrays.asList(work), new RuntimeException(), invalidWork::add); assertThat(executedWork).isEmpty(); assertThat(invalidWork).containsExactly(work.work()); } @Test - public void logAndProcessFailure_doesNotRetryAfterLocalRetryTimeout() throws Throwable { + public void logAndProcessFailureBatch_doesNotRetryAfterLocalRetryTimeout() throws Throwable { Set executedWork = new HashSet<>(); ExecutableWork veryOldWork = createWork(() -> Instant.now().minus(Duration.standardDays(30)), executedWork::add); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, veryOldWork, new RuntimeException(), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, + Arrays.asList(veryOldWork), + new RuntimeException(), + invalidWork::add); assertThat(executedWork).isEmpty(); assertThat(invalidWork).contains(veryOldWork.work()); } @Test - public void logAndProcessFailure_retriesOnUncaughtUnhandledException_streamingEngine() + public void logAndProcessFailureBatch_retriesOnUncaughtUnhandledException_streamingEngine() throws Throwable { CountDownLatch runWork = new CountDownLatch(1); ExecutableWork work = createWork(ignored -> runWork.countDown()); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new RuntimeException(), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, Arrays.asList(work), new RuntimeException(), invalidWork::add); runWork.await(); assertThat(invalidWork).isEmpty(); } @Test - public void logAndProcessFailure_retriesOnUncaughtUnhandledException_streamingAppliance() + public void logAndProcessFailureBatch_retriesOnUncaughtUnhandledException_streamingAppliance() throws Throwable { CountDownLatch runWork = new CountDownLatch(1); ExecutableWork work = createWork(ignored -> runWork.countDown()); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingApplianceFailureReporter(false)); Set invalidWork = new HashSet<>(); - workFailureProcessor.logAndProcessFailure( - DEFAULT_COMPUTATION_ID, work, new RuntimeException(), invalidWork::add); + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, Arrays.asList(work), new RuntimeException(), invalidWork::add); runWork.await(); assertThat(invalidWork).isEmpty(); } + + @Test + public void logAndProcessFailureBatch_retryAll() throws Throwable { + CountDownLatch runWork1 = new CountDownLatch(1); + CountDownLatch runWork2 = new CountDownLatch(1); + ExecutableWork work1 = createWork(ignored -> runWork1.countDown()); + ExecutableWork work2 = createWork(ignored -> runWork2.countDown()); + + WorkFailureProcessor workFailureProcessor = + createWorkFailureProcessor(streamingEngineFailureReporter()); + Set invalidWork = new HashSet<>(); + + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, + Arrays.asList(work1, work2), + new RuntimeException(), + invalidWork::add); + + runWork1.await(); + runWork2.await(); + assertThat(invalidWork).isEmpty(); + } + + @Test + public void logAndProcessFailureBatch_mixRetryAndAbort() throws Throwable { + CountDownLatch runWork1 = new CountDownLatch(1); + Set executedWork2 = new HashSet<>(); + ExecutableWork work1 = createWork(ignored -> runWork1.countDown()); + ExecutableWork work2 = createWork(executedWork2::add); + work2.work().setFailed(); + + WorkFailureProcessor workFailureProcessor = + createWorkFailureProcessor(streamingEngineFailureReporter()); + Set invalidWork = new HashSet<>(); + + workFailureProcessor.logAndProcessFailureBatch( + DEFAULT_COMPUTATION_ID, + Arrays.asList(work1, work2), + new RuntimeException(), + invalidWork::add); + + runWork1.await(); + assertThat(executedWork2).isEmpty(); + assertThat(invalidWork).containsExactly(work2.work()); + } } diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index aaa09c105fc3..a7a99e2ca5a1 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -678,9 +678,24 @@ message WorkItemCommitRequest { reserved 6, 23; } +message MultiKeyWorkItemCommitRequest { + optional Uint128Proto key_group = 7; + + repeated WorkItemCommitRequest requests = 1; + + repeated OutputMessageBundle output_messages = 2; + + repeated PubSubMessageBundle pubsub_messages = 3; + + repeated int64 finalize_ids = 4 [packed = true]; + + reserved 6; +} + message ComputationCommitWorkRequest { required string computation_id = 1; repeated WorkItemCommitRequest requests = 2; + repeated MultiKeyWorkItemCommitRequest multi_key_requests = 3; } message CommitWorkRequest { @@ -906,6 +921,14 @@ message StreamingCommitRequestChunk { // before handing off to the WindmillHost for processing. optional int64 remaining_bytes_for_work_item = 4; optional bytes serialized_work_item_commit = 5; + + enum CommitType { + COMMIT_TYPE_UNSPECIFIED = 0; + COMMIT_TYPE_SINGLE_KEY = 1; + COMMIT_TYPE_MULTI_KEY = 2; + } + + optional CommitType commit_type = 7; } message StreamingCommitResponse {