Skip to content

Commit

Permalink
feat(core): allow taskDefault to use a prefix (and not full class nam…
Browse files Browse the repository at this point in the history
…e) (#2657)
  • Loading branch information
tchiotludo authored Dec 8, 2023
1 parent 6699436 commit f5c4aab
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 67 deletions.
12 changes: 10 additions & 2 deletions core/src/main/java/io/kestra/core/serializers/YamlFlowParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,17 @@ public <T> T parse(String input, Class<T> cls) {
return readFlow(mapper, input, cls, type(cls));
}

public <T> T parse(Map<String, Object> input, Class<T> cls) {

public <T> T parse(Map<String, Object> input, Class<T> cls, Boolean strict) {
ObjectMapper currentMapper = mapper;

if (!strict) {
currentMapper = mapper.copy()
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
}

try {
return mapper.convertValue(input, cls);
return currentMapper.convertValue(input, cls);
} catch (IllegalArgumentException e) {
if(e.getCause() instanceof JsonProcessingException jsonProcessingException) {
jsonProcessingExceptionHandler(input, type(cls), jsonProcessingException);
Expand Down
24 changes: 13 additions & 11 deletions core/src/main/java/io/kestra/core/services/TaskDefaultService.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,24 +108,23 @@ public Flow injectDefaults(Flow flow) throws ConstraintViolationException {
}

// we apply default and overwrite with forced
if (defaults.size() > 0) {
if (!defaults.isEmpty()) {
flowAsMap = (Map<String, Object>) recursiveDefaults(flowAsMap, defaults);
}

if (forced.size() > 0) {
if (!forced.isEmpty()) {
flowAsMap = (Map<String, Object>) recursiveDefaults(flowAsMap, forced);
}

if (taskDefaults != null) {
flowAsMap.put("taskDefaults", taskDefaults);
}

return yamlFlowParser.parse(flowAsMap, Flow.class);
return yamlFlowParser.parse(flowAsMap, Flow.class, false);
}

private static Object recursiveDefaults(Object object, Map<String, List<TaskDefault>> defaults) {
if (object instanceof Map) {
Map<?, ?> value = (Map<?, ?>) object;
if (object instanceof Map<?, ?> value) {
if (value.containsKey("type")) {
value = defaults(value, defaults);
}
Expand All @@ -138,8 +137,7 @@ private static Object recursiveDefaults(Object object, Map<String, List<TaskDefa
recursiveDefaults(e.getValue(), defaults)
))
.collect(HashMap::new, (m, v) -> m.put(v.getKey(), v.getValue()), HashMap::putAll);
} else if (object instanceof Collection) {
Collection<?> value = (Collection<?>) object;
} else if (object instanceof Collection<?> value) {
return value
.stream()
.map(r -> recursiveDefaults(r, defaults))
Expand All @@ -152,19 +150,23 @@ private static Object recursiveDefaults(Object object, Map<String, List<TaskDefa
@SuppressWarnings("unchecked")
protected static Map<?, ?> defaults(Map<?, ?> task, Map<String, List<TaskDefault>> defaults) {
Object type = task.get("type");
if (!(type instanceof String)) {
if (!(type instanceof String taskType)) {
return task;
}

String taskType = (String) type;
List<TaskDefault> matching = defaults.entrySet()
.stream()
.filter(e -> e.getKey().equals(taskType) || taskType.startsWith(e.getKey()))
.flatMap(e -> e.getValue().stream())
.toList();

if (!defaults.containsKey(taskType)) {
if (matching.isEmpty()) {
return task;
}

Map<String, Object> result = (Map<String, Object>) task;

for (TaskDefault taskDefault : defaults.get(taskType)) {
for (TaskDefault taskDefault : matching) {
if (taskDefault.isForced()) {
result = MapUtils.merge(result, taskDefault.getValues());
} else {
Expand Down
6 changes: 0 additions & 6 deletions core/src/test/java/io/kestra/core/runners/RunContextTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,6 @@ void largeInput() throws IOException, InterruptedException {
assertThat(storageInterface.size(null, uri), is(size + 1));
}

@Test
void invalidTaskDefaults() throws TimeoutException, IOException, URISyntaxException {
repositoryLoader.load(Objects.requireNonNull(ListenersTest.class.getClassLoader().getResource("flows/tests/invalid-task-defaults.yaml")));
taskDefaultsCaseTest.invalidTaskDefaults();
}

@Test
void metricsIncrement() {
RunContext runContext = runContextFactory.of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,59 +2,39 @@

import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.executions.LogEntry;
import io.kestra.core.models.executions.NextTaskRun;
import io.kestra.core.models.executions.TaskRun;
import io.kestra.core.models.hierarchies.GraphCluster;
import io.kestra.core.models.hierarchies.RelationType;
import io.kestra.core.models.tasks.FlowableTask;
import io.kestra.core.models.tasks.ResolvedTask;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.queues.QueueFactoryInterface;
import io.kestra.core.queues.QueueInterface;
import io.kestra.core.repositories.FlowRepositoryInterface;
import io.kestra.core.services.TaskDefaultService;
import io.kestra.core.utils.GraphUtils;
import io.kestra.core.utils.TestsUtils;
import jakarta.inject.Inject;
import jakarta.inject.Named;
import jakarta.inject.Singleton;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;

import javax.validation.ConstraintViolationException;
import javax.validation.Valid;
import javax.validation.constraints.NotEmpty;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.validation.Valid;
import javax.validation.constraints.NotEmpty;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;

@Singleton
public class TaskDefaultsCaseTest {
@Inject
private RunnerUtils runnerUtils;

@Inject
private TaskDefaultService taskDefaultService;

@Inject
private FlowRepositoryInterface flowRepository;

@Inject
@Named(QueueFactoryInterface.WORKERTASKLOG_NAMED)
private QueueInterface<LogEntry> logQueue;

public void taskDefaults() throws TimeoutException {
Execution execution = runnerUtils.runOne(null, "io.kestra.tests", "task-defaults", Duration.ofSeconds(60));

Expand All @@ -75,24 +55,6 @@ public void taskDefaults() throws TimeoutException {
assertThat(execution.getTaskRunList().get(6).getOutputs().get("def"), is("3"));
}

public void invalidTaskDefaults() throws TimeoutException {
List<LogEntry> logs = new CopyOnWriteArrayList<>();
logQueue.receive(either -> logs.add(either.getLeft()));

Execution execution = runnerUtils.runOne(null, "io.kestra.tests", "invalid-task-defaults", Duration.ofSeconds(60));

assertThat(execution.getTaskRunList(), hasSize(1));
LogEntry matchingLog = TestsUtils.awaitLog(logs, log -> log.getMessage().contains("Unrecognized field \"invalid\""));
assertThat(matchingLog, notNullValue());

ConstraintViolationException constraintViolationException = assertThrows(ConstraintViolationException.class, () -> taskDefaultService.injectDefaults(flowRepository
.findById(null, "io.kestra.tests", "invalid-task-defaults", Optional.empty())
.orElseThrow()));

assertThat(constraintViolationException.getConstraintViolations().size(), is(1));
assertThat(constraintViolationException.getMessage(), containsString("Unrecognized field \"invalid\""));
}

@SuperBuilder
@ToString
@EqualsAndHashCode
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.kestra.core.serializers;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.kestra.core.models.flows.Flow;
import io.kestra.core.models.flows.Input;
Expand All @@ -24,6 +25,7 @@
import java.time.Duration;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Map;
import java.util.Optional;

import static org.hamcrest.MatcherAssert.assertThat;
Expand Down Expand Up @@ -217,6 +219,20 @@ void invalidProperty() {
assertThat(exception.getConstraintViolations().iterator().next().getPropertyPath().toString(), is("io.kestra.core.models.flows.Flow[\"tasks\"]->java.util.ArrayList[0]->io.kestra.core.tasks.debugs.Return[\"invalid\"]"));
}

@Test
void invalidPropertyOk() throws IOException {
URL resource = TestsUtils.class.getClassLoader().getResource("flows/invalids/invalid-property.yaml");
assert resource != null;

File file = new File(resource.getFile());
String flowSource = Files.readString(file.toPath(), Charset.defaultCharset());
TypeReference<Map<String, Object>> TYPE_REFERENCE = new TypeReference<>() {};
Map<String, Object> flow = JacksonMapper.ofYaml().readValue(flowSource, TYPE_REFERENCE);

Flow parse = yamlFlowParser.parse(flow, Flow.class, false);

assertThat(parse.getId(), is("duplicate"));
}

@Test
void includeFailed() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,47 @@ public void forced() {
assertThat(((DefaultTester) injected.getTasks().get(0)).getSet(), is(123));
}

@Test
public void prefix() {
DefaultTester task = DefaultTester.builder()
.id("test")
.type(DefaultTester.class.getName())
.set(666)
.build();

Flow flow = Flow.builder()
.triggers(List.of(
DefaultTriggerTester.builder()
.id("trigger")
.type(DefaultTriggerTester.class.getName())
.conditions(List.of(VariableCondition.builder()
.type(VariableCondition.class.getName())
.build())
)
.build()
))
.tasks(Collections.singletonList(task))
.taskDefaults(List.of(
new TaskDefault(DefaultTester.class.getName(), false, ImmutableMap.of(
"set", 789
)),
new TaskDefault("io.kestra.core.services.", false, ImmutableMap.of(
"value", 2,
"set", 456,
"arrays", Collections.singletonList(1)
)),
new TaskDefault("io.kestra.core.services2.", false, ImmutableMap.of(
"value", 3
))
))
.build();

Flow injected = taskDefaultService.injectDefaults(flow);

assertThat(((DefaultTester) injected.getTasks().get(0)).getSet(), is(666));
assertThat(((DefaultTester) injected.getTasks().get(0)).getValue(), is(2));
}

@SuperBuilder
@ToString
@EqualsAndHashCode
Expand Down
6 changes: 0 additions & 6 deletions jdbc/src/test/java/io/kestra/jdbc/runner/JdbcRunnerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,6 @@ void taskDefaults() throws TimeoutException, IOException, URISyntaxException {
taskDefaultsCaseTest.taskDefaults();
}

@Test
void invalidTaskDefaults() throws TimeoutException, IOException, URISyntaxException {
repositoryLoader.load(Objects.requireNonNull(ListenersTest.class.getClassLoader().getResource("flows/tests/invalid-task-defaults.yaml")));
taskDefaultsCaseTest.invalidTaskDefaults();
}

@Test
void flowWaitSuccess() throws Exception {
flowCaseTest.waitSuccess();
Expand Down

0 comments on commit f5c4aab

Please sign in to comment.