Skip to content

Commit

Permalink
feat(core): introduce DynamicTask for dynamic task generation for a w…
Browse files Browse the repository at this point in the history
…orker task
  • Loading branch information
tchiotludo committed Mar 14, 2022
1 parent e7c73dd commit e05d9d6
Show file tree
Hide file tree
Showing 12 changed files with 152 additions and 52 deletions.
9 changes: 9 additions & 0 deletions core/src/main/java/io/kestra/core/models/flows/State.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ public State(Type state, State actual) {
this.histories.add(new History(this.current, Instant.now()));
}

public static State of(Type state, List<History> histories) {
State result = new State(state);

result.histories.removeIf(history -> true);
result.histories.addAll(histories);

return result;
}

public State withState(Type state) {
if (this.current == state) {
log.warn("Can't change state, already " + current);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
@Getter
public class GraphCluster extends AbstractGraphTask {
@JsonIgnore
private Graph<AbstractGraphTask, Relation> graph = new Graph<>();
private final Graph<AbstractGraphTask, Relation> graph = new Graph<>();

@JsonIgnore
private GraphClusterRoot root;
private final GraphClusterRoot root;

@JsonIgnore
private GraphClusterEnd end;
private final GraphClusterEnd end;

public GraphCluster() {
super();
Expand All @@ -36,10 +36,4 @@ public GraphCluster(Task task, TaskRun taskRun, List<String> values, RelationTyp
graph.addNode(this.root);
graph.addNode(this.end);
}

public GraphCluster(GraphCluster graphTask, TaskRun taskRun, List<String> values) {
super(graphTask.getTask(), taskRun, values, graphTask.getRelationType());

this.graph = graphTask.graph;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.kestra.core.models.tasks;

public interface DynamicTask {

}
16 changes: 10 additions & 6 deletions core/src/main/java/io/kestra/core/runners/ExecutorService.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import io.kestra.core.models.executions.TaskRun;
import io.kestra.core.models.flows.Flow;
import io.kestra.core.models.flows.State;
import io.kestra.core.models.tasks.DynamicTask;
import io.kestra.core.models.tasks.FlowableTask;
import io.kestra.core.models.tasks.ResolvedTask;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.services.ConditionService;
import io.kestra.core.tasks.flows.Worker;
import io.micronaut.context.ApplicationContext;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
Expand Down Expand Up @@ -568,6 +568,13 @@ private Executor handleFlowTask(final Executor executor) {
}

public Execution addDynamicTaskRun(Execution execution, Flow flow, WorkerTaskResult workerTaskResult) throws InternalException {
ArrayList<TaskRun> taskRuns = new ArrayList<>(execution.getTaskRunList());

// declared dynamic tasks
if (workerTaskResult.getDynamicTaskRuns() != null) {
taskRuns.addAll(workerTaskResult.getDynamicTaskRuns());
}

// if parent, can be a Worker task that generate dynamic tasks
if (workerTaskResult.getTaskRun().getParentTaskRunId() != null) {
try {
Expand All @@ -576,16 +583,13 @@ public Execution addDynamicTaskRun(Execution execution, Flow flow, WorkerTaskRes
TaskRun parentTaskRun = execution.findTaskRunByTaskRunId(workerTaskResult.getTaskRun().getParentTaskRunId());
Task parentTask = flow.findTaskByTaskId(parentTaskRun.getTaskId());

if (parentTask instanceof Worker) {
ArrayList<TaskRun> taskRuns = new ArrayList<>(execution.getTaskRunList());
if (parentTask instanceof DynamicTask) {
taskRuns.add(workerTaskResult.getTaskRun());

return execution.withTaskRunList(taskRuns);
}
}
}

return null;
return taskRuns.size() > execution.getTaskRunList().size() ? execution.withTaskRunList(taskRuns) : null;
}

public void log(Logger log, Boolean in, WorkerTask value) {
Expand Down
20 changes: 16 additions & 4 deletions core/src/main/java/io/kestra/core/runners/RunContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,21 @@
public class RunContext {
private final static ObjectMapper MAPPER = JacksonMapper.ofJson();

private VariableRenderer variableRenderer;
// Injected
private ApplicationContext applicationContext;
private VariableRenderer variableRenderer;
private StorageInterface storageInterface;
private String envPrefix;
private MetricRegistry meterRegistry;
private Path tempBasedPath;

private URI storageOutputPrefix;
private URI storageExecutionPrefix;
private String envPrefix;
private Map<String, Object> variables;
private List<AbstractMetricEntry<?>> metrics = new ArrayList<>();
private MetricRegistry meterRegistry;
private RunContextLogger runContextLogger;
private Path tempBasedPath;
private final List<WorkerTaskResult> dynamicWorkerTaskResult = new ArrayList<>();

protected transient Path temporaryDirectory;

/**
Expand Down Expand Up @@ -572,6 +576,14 @@ private String metricPrefix() {
return String.join(".", values);
}

public void dynamicWorkerResult(List<WorkerTaskResult> workerTaskResults) {
dynamicWorkerTaskResult.addAll(workerTaskResults);
}

public List<WorkerTaskResult> dynamicWorkerResults() {
return dynamicWorkerTaskResult;
}

public synchronized Path tempDir() {
return this.tempDir(true);
}
Expand Down
9 changes: 6 additions & 3 deletions core/src/main/java/io/kestra/core/runners/Worker.java
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ private WorkerTaskResult run(WorkerTask workerTask, Boolean cleanUp) throws Queu
)
.get(() -> this.runAttempt(current.get()));

// save dynamic WorkerResults since cleanUpTransient will remove them
List<WorkerTaskResult> dynamicWorkerResults = finalWorkerTask.getRunContext().dynamicWorkerResults();

// remove tmp directory
if (cleanUp) {
finalWorkerTask.getRunContext().cleanup();
}
Expand Down Expand Up @@ -259,15 +263,15 @@ private WorkerTaskResult run(WorkerTask workerTask, Boolean cleanUp) throws Queu
// So we just tryed to failed the status of the worker task, in this case, no log can't be happend, just
// changing status must work in order to finish current task (except if we are near the upper bound size).
try {
WorkerTaskResult workerTaskResult = new WorkerTaskResult(finalWorkerTask);
WorkerTaskResult workerTaskResult = new WorkerTaskResult(finalWorkerTask, dynamicWorkerResults);
this.workerTaskResultQueue.emit(workerTaskResult);
return workerTaskResult;
} catch (QueueException e) {
finalWorkerTask = workerTask
.withTaskRun(workerTask.getTaskRun()
.withState(State.Type.FAILED)
);
WorkerTaskResult workerTaskResult = new WorkerTaskResult(finalWorkerTask);
WorkerTaskResult workerTaskResult = new WorkerTaskResult(finalWorkerTask, dynamicWorkerResults);
this.workerTaskResultQueue.emit(workerTaskResult);
return workerTaskResult;
} finally {
Expand Down Expand Up @@ -382,7 +386,6 @@ private List<TaskRunAttempt> addAttempt(WorkerTask workerTask, TaskRunAttempt ta
.build();
}

@SuppressWarnings("UnstableApiUsage")
public AtomicInteger getMetricRunningCount(WorkerTask workerTask) {
String[] tags = this.metricRegistry.tags(workerTask);
Arrays.sort(tags);
Expand Down
19 changes: 19 additions & 0 deletions core/src/main/java/io/kestra/core/runners/WorkerTaskResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import lombok.Builder;
import lombok.Value;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import javax.validation.constraints.NotNull;

@Value
Expand All @@ -14,7 +17,23 @@ public class WorkerTaskResult {
@NotNull
TaskRun taskRun;

List<TaskRun> dynamicTaskRuns;

public WorkerTaskResult(TaskRun taskRun) {
this.taskRun = taskRun;
this.dynamicTaskRuns = new ArrayList<>();
}

public WorkerTaskResult(WorkerTask workerTask) {
this.taskRun = workerTask.getTaskRun();
this.dynamicTaskRuns = new ArrayList<>();
}

public WorkerTaskResult(WorkerTask workerTask, List<WorkerTaskResult> dynamicWorkerResults) {
this.taskRun = workerTask.getTaskRun();
this.dynamicTaskRuns = dynamicWorkerResults
.stream()
.map(WorkerTaskResult::getTaskRun)
.collect(Collectors.toList());
}
}
74 changes: 54 additions & 20 deletions core/src/main/java/io/kestra/core/services/ExecutionService.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,22 @@
import io.kestra.core.models.flows.Flow;
import io.kestra.core.models.flows.State;
import io.kestra.core.models.hierarchies.GraphCluster;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.repositories.FlowRepositoryInterface;
import io.kestra.core.tasks.flows.Worker;
import io.kestra.core.utils.IdUtils;
import io.micronaut.context.ApplicationContext;
import io.micronaut.core.annotation.Nullable;

import java.util.*;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import static io.kestra.core.utils.Rethrow.throwFunction;
import static io.kestra.core.utils.Rethrow.throwPredicate;

@Singleton
public class ExecutionService {
Expand All @@ -37,19 +41,12 @@ public Execution restart(final Execution execution, @Nullable Integer revision)

final Flow flow = flowRepositoryInterface.findByExecution(execution);

Set<String> taskRunToRestart = this.taskRunWithAncestors(
Set<String> taskRunToRestart = this.taskRunToRestart(
execution,
execution
.getTaskRunList()
.stream()
.filter(taskRun -> taskRun.getState().getCurrent().isFailed())
.collect(Collectors.toList())
flow,
taskRun -> taskRun.getState().getCurrent().isFailed()
);

if (taskRunToRestart.size() == 0) {
throw new IllegalArgumentException("No failed task found to restart execution from !");
}

Map<String, String> mappingTaskRunId = this.mapTaskRunId(execution, revision == null);
final String newExecutionId = revision != null ? IdUtils.create() : null;

Expand Down Expand Up @@ -77,6 +74,50 @@ public Execution restart(final Execution execution, @Nullable Integer revision)
return revision != null ? newExecution.withFlowRevision(revision) : newExecution;
}

private Set<String> taskRunToRestart(Execution execution, Flow flow, Predicate<TaskRun> predicate) throws InternalException {
// Original tasks to be restarted
Set<String> originalTaskRunToRestart = this
.taskRunWithAncestors(
execution,
execution
.getTaskRunList()
.stream()
.filter(predicate)
.collect(Collectors.toList())
);

// we removed Worker
Set<String> finalTaskRunToRestart = originalTaskRunToRestart
.stream()
.filter(throwPredicate(s -> {
TaskRun taskRun = execution.findTaskRunByTaskRunId(s);
Task task = flow.findTaskByTaskId(taskRun.getTaskId());
return !(task instanceof Worker);
}))
.collect(Collectors.toSet());

// we removed task with parent and no more child
Set<String> clonedLambda = finalTaskRunToRestart;
finalTaskRunToRestart = finalTaskRunToRestart
.stream()
.filter(throwPredicate(s -> {
TaskRun taskRun = execution.findTaskRunByTaskRunId(s);
return taskRun.getParentTaskRunId() == null || clonedLambda.contains(taskRun.getParentTaskRunId());
}))
.collect(Collectors.toSet());


if (finalTaskRunToRestart.size() == 0) {
if (originalTaskRunToRestart.size() > 0) {
throw new IllegalArgumentException("No valid task to restart execution from! Worker task can't be restarted.");
} else {
throw new IllegalArgumentException("No failed task found to restart execution from!");
}
}

return finalTaskRunToRestart;
}

public Execution replay(final Execution execution, String taskRunId, @Nullable Integer revision) throws Exception {
if (!execution.getState().isTerninated()) {
throw new IllegalStateException("Execution must be terminated to be restarted, " +
Expand All @@ -87,19 +128,12 @@ public Execution replay(final Execution execution, String taskRunId, @Nullable I
final Flow flow = flowRepositoryInterface.findByExecution(execution);
GraphCluster graphCluster = GraphService.of(flow, execution);

Set<String> taskRunToRestart = this.taskRunWithAncestors(
Set<String> taskRunToRestart = this.taskRunToRestart(
execution,
execution
.getTaskRunList()
.stream()
.filter(taskRun -> taskRun.getId().equals(taskRunId))
.collect(Collectors.toList())
flow,
taskRun -> taskRun.getId().equals(taskRunId)
);

if (taskRunToRestart.size() == 0) {
throw new IllegalArgumentException("No task found to restart execution from !");
}

Map<String, String> mappingTaskRunId = this.mapTaskRunId(execution, false);
final String newExecutionId = IdUtils.create();

Expand Down
3 changes: 2 additions & 1 deletion core/src/main/java/io/kestra/core/tasks/flows/Worker.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.kestra.core.models.executions.NextTaskRun;
import io.kestra.core.models.executions.TaskRun;
import io.kestra.core.models.flows.State;
import io.kestra.core.models.tasks.DynamicTask;
import io.kestra.core.models.tasks.ResolvedTask;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.runners.RunContext;
Expand Down Expand Up @@ -58,7 +59,7 @@
)
}
)
public class Worker extends Sequential {
public class Worker extends Sequential implements DynamicTask {
@Override
public List<NextTaskRun> resolveNexts(RunContext runContext, Execution execution, TaskRun parentTaskRun) throws IllegalVariableEvaluationException {
List<ResolvedTask> childTasks = this.childTasks(runContext, parentTaskRun);
Expand Down
18 changes: 11 additions & 7 deletions core/src/main/java/io/kestra/core/tasks/scripts/AbstractBash.java
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,7 @@ protected ScriptOutput run(RunContext runContext, Supplier<String> supplier) thr
workingDirectory,
finalCommandsWithInterpreter(commandAsString),
this.finalEnv(),
(inputStream, isStdErr) -> {
AbstractLogThread thread = new LogThread(logger, inputStream, isStdErr, runContext);
thread.setName("bash-log-" + (isStdErr ? "-err" : "-out"));
thread.start();

return thread;
}
this.defaultLogSupplier(logger, runContext)
);

// upload output files
Expand All @@ -234,6 +228,16 @@ protected ScriptOutput run(RunContext runContext, Supplier<String> supplier) thr
.build();
}

protected LogSupplier defaultLogSupplier(Logger logger, RunContext runContext) {
return (inputStream, isStdErr) -> {
AbstractLogThread thread = new LogThread(logger, inputStream, isStdErr, runContext);
thread.setName("bash-log-" + (isStdErr ? "-err" : "-out"));
thread.start();

return thread;
};
}

protected RunResult run(RunContext runContext, Logger logger, Path workingDirectory, List<String> commandsWithInterpreter, Map<String, String> env, LogSupplier logSupplier) throws Exception {
ScriptRunnerInterface executor;
if (this.runner == Runner.DOCKER) {
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit e05d9d6

Please sign in to comment.