Example 1 with START_NODE_ID

Use of org.flyte.api.v1.Node.START_NODE_ID in project flytekit-java by flyteorg.

From class ExecutionNodeCompiler, method compile.

static ExecutionNode compile(Node node, Map<String, RunnableTask> runnableTasks, Map<String, DynamicWorkflowTask> dynamicWorkflowTasks) {
    // Upstream ids come from promise bindings plus the node's explicit upstream ids.
    List<String> upstreamNodeIds = new ArrayList<>();
    node.inputs().stream()
            .map(Binding::binding)
            .flatMap(ExecutionNodeCompiler::unpackBindingData)
            .filter(x -> x.kind() == BindingData.Kind.PROMISE)
            .map(x -> x.promise().nodeId())
            .forEach(upstreamNodeIds::add);
    upstreamNodeIds.addAll(node.upstreamNodeIds());
    // A node with no upstream dependencies depends on the start node.
    if (upstreamNodeIds.isEmpty()) {
        upstreamNodeIds.add(START_NODE_ID);
    }
    if (node.branchNode() != null) {
        throw new IllegalArgumentException("BranchNode isn't yet supported for local execution");
    }
    if (node.workflowNode() != null) {
        throw new IllegalArgumentException("WorkflowNode isn't yet supported for local execution");
    }
    String taskName = node.taskNode().referenceId().name();
    DynamicWorkflowTask dynamicWorkflowTask = dynamicWorkflowTasks.get(taskName);
    RunnableTask runnableTask = runnableTasks.get(taskName);
    if (dynamicWorkflowTask != null) {
        throw new IllegalArgumentException("DynamicWorkflowTask isn't yet supported for local execution");
    }
    // Throws NullPointerException with this message when the task is unknown.
    Objects.requireNonNull(runnableTask, () -> String.format("Couldn't find task [%s]", taskName));
    // One initial attempt plus the configured number of retries.
    int attempts = runnableTask.getRetries().retries() + 1;
    return ExecutionNode.builder()
            .nodeId(node.id())
            .bindings(node.inputs())
            .runnableTask(runnableTask)
            .upstreamNodeIds(upstreamNodeIds)
            .attempts(attempts)
            .build();
}
Also used : Binding(org.flyte.api.v1.Binding) Node(org.flyte.api.v1.Node) Collections.emptyList(java.util.Collections.emptyList) START_NODE_ID(org.flyte.api.v1.Node.START_NODE_ID) DynamicWorkflowTask(org.flyte.api.v1.DynamicWorkflowTask) Set(java.util.Set) HashMap(java.util.HashMap) Deque(java.util.Deque) RunnableTask(org.flyte.api.v1.RunnableTask) Collections.singletonList(java.util.Collections.singletonList) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) Objects(java.util.Objects) Collectors.toList(java.util.stream.Collectors.toList) List(java.util.List) BindingData(org.flyte.api.v1.BindingData) Stream(java.util.stream.Stream) Map(java.util.Map) ArrayDeque(java.util.ArrayDeque) Comparator(java.util.Comparator)
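
The upstream-id handling in compile can be tried in isolation. The following is a minimal sketch with no Flyte dependency: the class UpstreamSketch, the method resolveUpstream, and the string value given to START_NODE_ID are hypothetical stand-ins chosen for illustration, not the flytekit-java API. It only shows the rule that a node with no upstream ids falls back to the start node.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical, dependency-free sketch; not part of flytekit-java.
final class UpstreamSketch {
    // Assumed placeholder value; the real constant is org.flyte.api.v1.Node.START_NODE_ID.
    static final String START_NODE_ID = "start-node";

    // Merges promise-derived ids and explicit upstream ids, in that order.
    // If the merged list is empty, the start node becomes the only upstream id.
    static List<String> resolveUpstream(List<String> promiseNodeIds, List<String> explicitUpstreamIds) {
        List<String> upstream = new ArrayList<>(promiseNodeIds);
        upstream.addAll(explicitUpstreamIds);
        if (upstream.isEmpty()) {
            upstream.add(START_NODE_ID);
        }
        return upstream;
    }

    public static void main(String[] args) {
        // No dependencies at all: prints [start-node]
        System.out.println(resolveUpstream(Collections.emptyList(), Collections.emptyList()));
        // One promise-derived dependency: prints [node-0]
        System.out.println(resolveUpstream(Collections.singletonList("node-0"), Collections.emptyList()));
    }
}
```

Like the original method, the sketch does not deduplicate ids, so a node referenced both by a promise binding and explicitly appears twice.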

Example 2 with START_NODE_ID

Use of org.flyte.api.v1.Node.START_NODE_ID in project flytekit-java by flyteorg.

From class ExecutionNodeCompilerTest, method testCompile_unknownTask.

@Test
void testCompile_unknownTask() {
    Node node = createNode("node-1", ImmutableList.of(START_NODE_ID));
    RuntimeException exception = assertThrows(
            RuntimeException.class,
            () -> ExecutionNodeCompiler.compile(node, emptyMap(), emptyMap()));
    assertEquals("Couldn't find task [unknownTask]", exception.getMessage());
}
Also used : Node(org.flyte.api.v1.Node) TaskNode(org.flyte.api.v1.TaskNode) Test(org.junit.jupiter.api.Test)
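
The test above expects a RuntimeException with a specific message. That is consistent with compile reporting a missing task through Objects.requireNonNull, which throws NullPointerException (a RuntimeException subclass) and builds the message from the supplier. A minimal sketch of that lookup pattern follows; TaskLookupSketch and findTask are hypothetical names, and Runnable stands in for RunnableTask.

```java
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

// Hypothetical, dependency-free sketch; not part of flytekit-java.
final class TaskLookupSketch {
    // Returns the task registered under the given name. When none is registered,
    // requireNonNull throws NullPointerException with a lazily formatted message.
    static Runnable findTask(Map<String, Runnable> tasks, String name) {
        Runnable task = tasks.get(name);
        return Objects.requireNonNull(task, () -> String.format("Couldn't find task [%s]", name));
    }

    public static void main(String[] args) {
        try {
            findTask(Collections.emptyMap(), "unknownTask");
        } catch (NullPointerException e) {
            // prints: Couldn't find task [unknownTask]
            System.out.println(e.getMessage());
        }
    }
}
```

Using the supplier overload means the message is formatted only when the lookup actually fails.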

Aggregations

Node (org.flyte.api.v1.Node)2 ArrayDeque (java.util.ArrayDeque)1 ArrayList (java.util.ArrayList)1 Collections.emptyList (java.util.Collections.emptyList)1 Collections.singletonList (java.util.Collections.singletonList)1 Comparator (java.util.Comparator)1 Deque (java.util.Deque)1 HashMap (java.util.HashMap)1 HashSet (java.util.HashSet)1 List (java.util.List)1 Map (java.util.Map)1 Objects (java.util.Objects)1 Set (java.util.Set)1 Collectors.toList (java.util.stream.Collectors.toList)1 Stream (java.util.stream.Stream)1 Binding (org.flyte.api.v1.Binding)1 BindingData (org.flyte.api.v1.BindingData)1 DynamicWorkflowTask (org.flyte.api.v1.DynamicWorkflowTask)1 START_NODE_ID (org.flyte.api.v1.Node.START_NODE_ID)1 RunnableTask (org.flyte.api.v1.RunnableTask)1