diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/ContainerTask.java b/flytekit-api/src/main/java/org/flyte/api/v1/ContainerTask.java index 3332417b0..0a3143f5e 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/ContainerTask.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/ContainerTask.java @@ -16,15 +16,10 @@ */ package org.flyte.api.v1; -import static java.util.Collections.emptyMap; - import java.util.List; /** Building block for tasks that execute arbitrary containers. */ -public interface ContainerTask { - - /** Specifies task name. */ - String getName(); +public interface ContainerTask extends Task { /** Specifies container image. */ String getImage(); @@ -38,22 +33,13 @@ public interface ContainerTask { /** Specifies container environment variables. */ List getEnv(); + @Override default String getType() { return "raw-container"; } - TypedInterface getInterface(); - /** Specifies container resource requests. */ default Resources getResources() { return Resources.builder().build(); } - - /** Specifies task retry policy. */ - RetryStrategy getRetries(); - - /** Specifies custom container parameters. */ - default Struct getCustom() { - return Struct.of(emptyMap()); - } } diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/RunnableTask.java b/flytekit-api/src/main/java/org/flyte/api/v1/RunnableTask.java index 8b6f07b12..e78f07da2 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/RunnableTask.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/RunnableTask.java @@ -21,15 +21,15 @@ import java.util.Map; /** Building block for tasks that execute Java code. */ -public interface RunnableTask { - - String getName(); +public interface RunnableTask extends Task { + @Override default String getType() { // FIXME default only for backwards-compatibility, remove in 0.3.x return "java-task"; } + @Override default Struct getCustom() { // FIXME default only for backwards-compatibility, remove in 0.3.x return Struct.of(emptyMap()); @@ -40,21 +40,5 @@ default Resources getResources() { return Resources.builder().build(); } - TypedInterface getInterface(); - Map run(Map inputs); - - RetryStrategy getRetries(); - - default boolean isCached() { - return false; - } - - default String getCacheVersion() { - return null; - } - - default boolean isCacheSerializable() { - return false; - } } diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/Task.java b/flytekit-api/src/main/java/org/flyte/api/v1/Task.java new file mode 100644 index 000000000..a738fb182 --- /dev/null +++ b/flytekit-api/src/main/java/org/flyte/api/v1/Task.java @@ -0,0 +1,61 @@ +/* + * Copyright 2021 Flyte Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.api.v1; + +import static java.util.Collections.emptyMap; + +/** Super interfaces for all tasks. */ +public interface Task { + + /** Specifies task name. */ + String getName(); + + /** Specifies the task type identifier. */ + String getType(); + + /** Specifies the task interface: inputs/outputs. */ + TypedInterface getInterface(); + + /** Specifies custom data about the task. */ + default Struct getCustom() { + return Struct.of(emptyMap()); + } + + /** Specifies task retry policy. */ + RetryStrategy getRetries(); + + /** + * Indicates whether the system should attempt to lookup this task's output to avoid duplication + * of work. + */ + default boolean isCached() { + return false; + } + + /** Indicates a logical version to apply to this task for the purpose of cache. */ + default String getCacheVersion() { + return null; + } + + /** + * Indicates whether the system should attempt to execute cached instances in serial to avoid + * duplicate work. + */ + default boolean isCacheSerializable() { + return false; + } +} diff --git a/jflyte/src/main/java/org/flyte/jflyte/ProjectClosure.java b/jflyte/src/main/java/org/flyte/jflyte/ProjectClosure.java index f62d9259d..80048a7af 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/ProjectClosure.java +++ b/jflyte/src/main/java/org/flyte/jflyte/ProjectClosure.java @@ -67,6 +67,7 @@ import org.flyte.api.v1.RunnableTask; import org.flyte.api.v1.RunnableTaskRegistrar; import org.flyte.api.v1.Struct; +import org.flyte.api.v1.Task; import org.flyte.api.v1.TaskIdentifier; import org.flyte.api.v1.TaskTemplate; import org.flyte.api.v1.WorkflowIdentifier; @@ -451,6 +452,25 @@ static TaskTemplate createTaskTemplateForRunnableTask(RunnableTask task, String .resources(resources) .build(); + return createTaskTemplate(task, container); + } + + @VisibleForTesting + static TaskTemplate createTaskTemplateForContainerTask(ContainerTask task) { + Resources resources = task.getResources(); + Container container = + Container.builder() + .command(task.getCommand()) + .args(task.getArgs()) + .image(task.getImage()) + .env(task.getEnv()) + .resources(resources) + .build(); + + return createTaskTemplate(task, container); + } + + private static TaskTemplate createTaskTemplate(Task task, Container container) { TaskTemplate.Builder templateBuilder = TaskTemplate.builder() .container(container) @@ -468,27 +488,6 @@ static TaskTemplate createTaskTemplateForRunnableTask(RunnableTask task, String return templateBuilder.build(); } - @VisibleForTesting - static TaskTemplate createTaskTemplateForContainerTask(ContainerTask task) { - Resources resources = task.getResources(); - Container container = - Container.builder() - .command(task.getCommand()) - .args(task.getArgs()) - .image(task.getImage()) - .env(task.getEnv()) - .resources(resources) - .build(); - - return TaskTemplate.builder() - .container(container) - .interface_(task.getInterface()) - .retries(task.getRetries()) - .type(task.getType()) - .custom(task.getCustom()) - .build(); - } - private static Optional javaToolOptionsEnv(Resources resources) { Map limits = resources.limits(); if (limits == null || !limits.containsKey(ResourceName.MEMORY)) { diff --git a/jflyte/src/test/java/org/flyte/jflyte/ProjectClosureTest.java b/jflyte/src/test/java/org/flyte/jflyte/ProjectClosureTest.java index f65c59419..8d42e6e26 100644 --- a/jflyte/src/test/java/org/flyte/jflyte/ProjectClosureTest.java +++ b/jflyte/src/test/java/org/flyte/jflyte/ProjectClosureTest.java @@ -18,7 +18,8 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; -import static org.flyte.api.v1.Resources.ResourceName.*; +import static org.flyte.api.v1.Resources.ResourceName.CPU; +import static org.flyte.api.v1.Resources.ResourceName.MEMORY; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; @@ -29,6 +30,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.reflect.Reflection; import com.google.protobuf.ByteString; import java.util.HashMap; import java.util.List; @@ -52,6 +54,7 @@ import org.flyte.api.v1.RetryStrategy; import org.flyte.api.v1.RunnableTask; import org.flyte.api.v1.Struct; +import org.flyte.api.v1.Task; import org.flyte.api.v1.TaskTemplate; import org.flyte.api.v1.TypedInterface; import org.flyte.api.v1.WorkflowIdentifier; @@ -59,7 +62,6 @@ import org.flyte.api.v1.WorkflowNode; import org.flyte.api.v1.WorkflowTemplate; import org.flyte.flytekit.SdkTypes; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; public class ProjectClosureTest { @@ -430,9 +432,6 @@ public void testCreateTaskTemplateForRunnableTask() { assertThat(result.custom(), equalTo(Struct.of(emptyMap()))); assertThat(result.retries(), equalTo(RetryStrategy.builder().retries(0).build())); assertThat(result.type(), equalTo("java-task")); - assertThat(result.discoverable(), equalTo(false)); - assertThat(result.cacheSerializable(), equalTo(false)); - assertThat(result.discoveryVersion(), nullValue()); } @Test @@ -468,44 +467,9 @@ public void testCreateTaskTemplateForRunnableTaskWithResources() { assertThat(result.custom(), equalTo(Struct.of(emptyMap()))); assertThat(result.retries(), equalTo(RetryStrategy.builder().retries(0).build())); assertThat(result.type(), equalTo("java-task")); - assertThat(result.discoverable(), equalTo(false)); - assertThat(result.cacheSerializable(), equalTo(false)); - assertThat(result.discoveryVersion(), nullValue()); } @Test - public void testCreateTaskTemplateForRunnableTaskWithCache() { - // given - RunnableTask task = createRunnableTaskWithCache(); - String image = "my-image"; - Resources expectedResources = Resources.builder().build(); - - // when - TaskTemplate result = ProjectClosure.createTaskTemplateForRunnableTask(task, image); - - // then - Container container = result.container(); - assertNotNull(container); - assertThat(container.image(), equalTo(image)); - assertThat(container.resources(), equalTo(expectedResources)); - assertThat(container.env(), equalTo(emptyList())); - assertThat( - result.interface_(), - equalTo( - TypedInterface.builder() - .inputs(SdkTypes.nulls().getVariableMap()) - .outputs(SdkTypes.nulls().getVariableMap()) - .build())); - assertThat(result.custom(), equalTo(Struct.of(emptyMap()))); - assertThat(result.retries(), equalTo(RetryStrategy.builder().retries(0).build())); - assertThat(result.type(), equalTo("java-task")); - assertThat(result.discoverable(), equalTo(true)); - assertThat(result.cacheSerializable(), equalTo(true)); - assertThat(result.discoveryVersion(), equalTo("0.0.1")); - } - - @Test - @Disabled("ContainerTasks don't currently support cache, fix coming up") public void testCreateTaskTemplateForContainerTask() { // given Resources expectedResources = @@ -542,6 +506,62 @@ public void testCreateTaskTemplateForContainerTask() { assertThat(result.type(), equalTo("raw-container")); } + @Test + public void testCreateTaskTemplateForTasksWithNoCache() { + // given + RunnableTask runnableTask = createRunnableTask(null); + ContainerTask containerTask = + createContainerTask( + Resources.builder().build(), + "test-image", + emptyList(), + ImmutableList.of("program"), + emptyList()); + + // when + TaskTemplate runnableTaskTemplate = + ProjectClosure.createTaskTemplateForRunnableTask(runnableTask, "image"); + TaskTemplate containerTakTemplate = + ProjectClosure.createTaskTemplateForContainerTask(containerTask); + List taskTemplates = ImmutableList.of(runnableTaskTemplate, containerTakTemplate); + + // then + for (TaskTemplate taskTemplate : taskTemplates) { + assertThat(taskTemplate.discoverable(), equalTo(false)); + assertThat(taskTemplate.cacheSerializable(), equalTo(false)); + assertThat(taskTemplate.discoveryVersion(), nullValue()); + } + } + + @Test + public void testCreateTaskTemplateForTasksWithCache() { + // given + RunnableTask runnableTask = wrapTaskWithRetries(RunnableTask.class, createRunnableTask(null)); + ContainerTask containerTask = + wrapTaskWithRetries( + ContainerTask.class, + createContainerTask( + Resources.builder().build(), + "test-image", + emptyList(), + ImmutableList.of("program"), + emptyList())); + + // when + TaskTemplate runnableTaskTemplate = + ProjectClosure.createTaskTemplateForRunnableTask(runnableTask, "image"); + TaskTemplate containerTakTemplate = + ProjectClosure.createTaskTemplateForContainerTask(containerTask); + List taskTemplates = ImmutableList.of(runnableTaskTemplate, containerTakTemplate); + + // then + for (TaskTemplate taskTemplate : taskTemplates) { + assertThat(taskTemplate.discoverable(), equalTo(true)); + assertThat(taskTemplate.cacheSerializable(), equalTo(true)); + assertThat(taskTemplate.discoveryVersion(), equalTo("0.0.1")); + } + } + private RunnableTask createRunnableTask(Resources expectedResources) { return new RunnableTask() { @Override @@ -631,46 +651,20 @@ public List getEnv() { }; } - private RunnableTask createRunnableTaskWithCache() { - return new RunnableTask() { - @Override - public String getName() { - return "my-test-task"; - } - - @Override - public TypedInterface getInterface() { - return TypedInterface.builder() - .inputs(SdkTypes.nulls().getVariableMap()) - .outputs(SdkTypes.nulls().getVariableMap()) - .build(); - } - - @Override - public Map run(Map inputs) { - System.out.println("Cached Hello World"); - return null; - } - - @Override - public RetryStrategy getRetries() { - return RetryStrategy.builder().retries(0).build(); - } - - @Override - public boolean isCached() { - return true; - } - - @Override - public String getCacheVersion() { - return "0.0.1"; - } - - @Override - public boolean isCacheSerializable() { - return true; - } - }; + private T wrapTaskWithRetries(Class taskClass, T task) { + return Reflection.newProxy( + taskClass, + (proxy, method, methodArgs) -> { + switch (method.getName()) { + case "isCached": + case "isCacheSerializable": + return true; + + case "getCacheVersion": + return "0.0.1"; + default: + return method.invoke(task, methodArgs); + } + }); } }