Skip to content

Commit

Permalink
fix(jdbc): merge the locked execution with the received execution to …
Browse files Browse the repository at this point in the history
…handle parallel tasks execution

Fixes #2179
  • Loading branch information
loicmathieu committed Oct 2, 2023
1 parent d90945d commit a9f0089
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 26 deletions.
63 changes: 41 additions & 22 deletions jdbc/src/main/java/io/kestra/jdbc/runner/JdbcExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ private void executionQueue(Either<Execution, DeserializationException> either)
}

Executor result = executionRepository.lock(message.getId(), pair -> {
Execution execution = pair.getLeft();
// as tasks can be processed in parallel, we must merge the execution from the database to the one we received in the queue
Execution execution = mergeExecution(pair.getLeft(), message);
ExecutorState executorState = pair.getRight();

final Flow flow = transform(this.flowRepository.findByExecution(execution), execution);
Expand All @@ -296,15 +297,15 @@ private void executionQueue(Either<Execution, DeserializationException> either)

executor = executorService.process(executor);

if (executor.getNexts().size() > 0 && deduplicateNexts(execution, executorState, executor.getNexts())) {
if (!executor.getNexts().isEmpty() && deduplicateNexts(execution, executorState, executor.getNexts())) {
executor.withExecution(
executorService.onNexts(executor.getFlow(), executor.getExecution(), executor.getNexts()),
"onNexts"
);
}

// worker task
if (executor.getWorkerTasks().size() > 0) {
if (!executor.getWorkerTasks().isEmpty()) {
List<WorkerTask> workerTasksDedup = executor
.getWorkerTasks()
.stream()
Expand All @@ -326,26 +327,26 @@ private void executionQueue(Either<Execution, DeserializationException> either)
}

// worker tasks results
if (executor.getWorkerTaskResults().size() > 0) {
if (!executor.getWorkerTaskResults().isEmpty()) {
executor.getWorkerTaskResults()
.forEach(workerTaskResultQueue::emit);
}

// schedulerDelay
if (executor.getExecutionDelays().size() > 0) {
if (!executor.getExecutionDelays().isEmpty()) {
executor.getExecutionDelays()
.forEach(executionDelay -> abstractExecutionDelayStorage.save(executionDelay));
}

// worker task execution watchers
if (executor.getWorkerTaskExecutions().size() > 0) {
if (!executor.getWorkerTaskExecutions().isEmpty()) {
workerTaskExecutionStorage.save(executor.getWorkerTaskExecutions());

List<WorkerTaskExecution> workerTasksExecutionDedup = executor
.getWorkerTaskExecutions()
.stream()
.filter(workerTaskExecution -> this.deduplicateWorkerTaskExecution(execution, executorState, workerTaskExecution.getTaskRun()))
.collect(Collectors.toList());
.toList();

workerTasksExecutionDedup
.forEach(workerTaskExecution -> {
Expand Down Expand Up @@ -409,6 +410,25 @@ private void executionQueue(Either<Execution, DeserializationException> either)
}
}

/**
 * Merges the task runs carried by the execution received from the queue into the locked execution.
 * <p>
 * Each task run of {@code message} is looked up by id in the execution being built; when a
 * matching task run is returned and the state of the message's task run has a later
 * {@code maxDate()} than the matching one, the message's task run replaces it. Task runs for
 * which the lookup returns {@code null}, or which are not newer, leave the execution untouched.
 * <p>
 * NOTE(review): whether {@code findTaskRunByTaskRunId} returns {@code null} or throws for an
 * unknown id is not visible from here — confirm against {@code Execution}.
 *
 * @param locked  the execution read under lock; returned as-is when nothing has to be merged
 * @param message the execution received from the queue; its task run list may be {@code null}
 * @return the locked execution updated with the newer task runs from {@code message}
 * @throws RuntimeException wrapping any {@code InternalException} raised while looking up or
 *                          replacing a task run
 */
private Execution mergeExecution(Execution locked, Execution message) {
    // nothing to merge when the received execution carries no task run
    if (message.getTaskRunList() == null) {
        return locked;
    }

    Execution merged = locked;
    try {
        for (TaskRun incoming : message.getTaskRunList()) {
            TaskRun current = merged.findTaskRunByTaskRunId(incoming.getId());
            if (current == null) {
                continue;
            }

            // only a task run that is strictly newer than the one already known wins
            boolean incomingIsNewer = incoming.getState().maxDate().isAfter(current.getState().maxDate());
            if (incomingIsNewer) {
                merged = merged.withTaskRun(incoming);
            }
        }
    } catch (InternalException e) {
        throw new RuntimeException(e);
    }

    return merged;
}

private void workerTaskResultQueue(Either<WorkerTaskResult, DeserializationException> either) {
if (either.isRight()) {
log.error("Unable to deserialize a worker task result: {}", either.getRight().getMessage());
Expand All @@ -425,20 +445,6 @@ private void workerTaskResultQueue(Either<WorkerTaskResult, DeserializationExcep
executorService.log(log, true, message);
}

// send metrics on terminated
if (message.getTaskRun().getState().isTerminated()) {
metricRegistry
.counter(MetricRegistry.EXECUTOR_TASKRUN_ENDED_COUNT, metricRegistry.tags(message))
.increment();

metricRegistry
.timer(MetricRegistry.EXECUTOR_TASKRUN_ENDED_DURATION, metricRegistry.tags(message))
.record(message.getTaskRun().getState().getDuration());

log.trace("TaskRun terminated: {}", message.getTaskRun());
workerJobRunningRepository.deleteByTaskRunId(message.getTaskRun().getId());
}

Executor executor = executionRepository.lock(message.getTaskRun().getExecutionId(), pair -> {
Execution execution = pair.getLeft();
Executor current = new Executor(execution, null);
Expand All @@ -459,10 +465,23 @@ private void workerTaskResultQueue(Either<WorkerTaskResult, DeserializationExcep
if (newExecution != null) {
current = current.withExecution(newExecution, "addDynamicTaskRun");
}
newExecution = current.getExecution().withTaskRun(message.getTaskRun());
current = current.withExecution(newExecution, "joinWorkerResult");

// send metrics on terminated
if (message.getTaskRun().getState().isTerminated()) {
metricRegistry
.counter(MetricRegistry.EXECUTOR_TASKRUN_ENDED_COUNT, metricRegistry.tags(message))
.increment();

metricRegistry
.timer(MetricRegistry.EXECUTOR_TASKRUN_ENDED_DURATION, metricRegistry.tags(message))
.record(message.getTaskRun().getState().getDuration());
}

// join worker result
return Pair.of(
current.withExecution(current.getExecution().withTaskRun(message.getTaskRun()), "joinWorkerResult"),
current,
pair.getRight()
);
} catch (InternalException e) {
Expand Down
5 changes: 1 addition & 4 deletions jdbc/src/test/java/io/kestra/jdbc/runner/JdbcRunnerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,7 @@ void eachParallelWithSubflowMissing() throws TimeoutException {

assertThat(execution, notNullValue());
assertThat(execution.getState().getCurrent(), is(State.Type.FAILED));
// on JDBC, when using an each parallel, the flow is failed even if not all subtasks of the each parallel are ended as soon as
// there is one failed task FIXME https://github.com/kestra-io/kestra/issues/2179
// so instead of asserting that all tasks FAILED we assert that at least two failed (the each parallel and one of its subtasks)
assertThat(execution.getTaskRunList().stream().filter(taskRun -> taskRun.getState().isFailed()).count(), greaterThanOrEqualTo(2L)); // Should be 3
assertThat(execution.getTaskRunList().stream().filter(taskRun -> taskRun.getState().isFailed()).count(), is(3L));
}

@Test
Expand Down

0 comments on commit a9f0089

Please sign in to comment.