
Example 6 with BranchNode

Use of org.flyte.api.v1.BranchNode in project flytekit-java by flyteorg.

From class ExecutionNodeCompiler, method compile:

// Compiles a single workflow Node into an ExecutionNode for the local runner.
static ExecutionNode compile(
        Node node,
        Map<String, RunnableTask> runnableTasks,
        Map<String, DynamicWorkflowTask> dynamicWorkflowTasks) {
    // Each promise among the node's input bindings (after unpackBindingData,
    // not shown, expands nested binding data) names an upstream node.
    List<String> upstreamNodeIds = new ArrayList<>();
    node.inputs().stream()
        .map(Binding::binding)
        .flatMap(ExecutionNodeCompiler::unpackBindingData)
        .filter(x -> x.kind() == BindingData.Kind.PROMISE)
        .map(x -> x.promise().nodeId())
        .forEach(upstreamNodeIds::add);
    // Add explicitly declared dependencies; a node with none runs after the start node.
    upstreamNodeIds.addAll(node.upstreamNodeIds());
    if (upstreamNodeIds.isEmpty()) {
        upstreamNodeIds.add(START_NODE_ID);
    }
    // Only task nodes can be executed locally; branch and sub-workflow nodes are rejected.
    if (node.branchNode() != null) {
        throw new IllegalArgumentException("BranchNode isn't yet supported for local execution");
    }
    if (node.workflowNode() != null) {
        throw new IllegalArgumentException("WorkflowNode isn't yet supported for local execution");
    }
    String taskName = node.taskNode().referenceId().name();
    DynamicWorkflowTask dynamicWorkflowTask = dynamicWorkflowTasks.get(taskName);
    RunnableTask runnableTask = runnableTasks.get(taskName);
    if (dynamicWorkflowTask != null) {
        throw new IllegalArgumentException("DynamicWorkflowTask isn't yet supported for local execution");
    }
    Objects.requireNonNull(runnableTask, () -> String.format("Couldn't find task [%s]", taskName));
    // One initial attempt plus the configured number of retries.
    int attempts = runnableTask.getRetries().retries() + 1;
    return ExecutionNode.builder()
        .nodeId(node.id())
        .bindings(node.inputs())
        .runnableTask(runnableTask)
        .upstreamNodeIds(upstreamNodeIds)
        .attempts(attempts)
        .build();
}
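The first half of compile derives a node's dependencies from its input bindings. The sketch below reproduces that logic in a self-contained form, assuming a simplified, hypothetical binding model (Data, Scalar, Promise, ListData, MapData) instead of the real org.flyte.api.v1.BindingData. The recursive unpack method is only a guess at what unpackBindingData does, since its body is not shown, and the "start-node" value of START_NODE_ID is a placeholder. It needs Java 17 (records and instanceof patterns).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

// Hypothetical, simplified model; these types are NOT the flytekit-java API.
final class UpstreamIds {
    // Placeholder value standing in for Node.START_NODE_ID.
    static final String START_NODE_ID = "start-node";

    interface Data {}
    record Scalar(Object value) implements Data {}
    record Promise(String nodeId, String var) implements Data {}
    record ListData(List<Data> items) implements Data {}
    record MapData(Map<String, Data> entries) implements Data {}

    // Assumed behaviour of unpackBindingData: flatten nested lists and maps
    // into a stream of leaf values (scalars and promises).
    static Stream<Data> unpack(Data data) {
        if (data instanceof ListData list) {
            return list.items().stream().flatMap(UpstreamIds::unpack);
        } else if (data instanceof MapData map) {
            return map.entries().values().stream().flatMap(UpstreamIds::unpack);
        }
        return Stream.of(data);
    }

    // Mirrors compile: promise node IDs first, then explicit dependencies,
    // falling back to the start node when there are none.
    static List<String> upstreamNodeIds(List<Data> inputs, List<String> explicitUpstream) {
        List<String> ids = new ArrayList<>();
        inputs.stream()
            .flatMap(UpstreamIds::unpack)
            .filter(d -> d instanceof Promise)
            .map(d -> ((Promise) d).nodeId())
            .forEach(ids::add);
        ids.addAll(explicitUpstream);
        if (ids.isEmpty()) {
            ids.add(START_NODE_ID);
        }
        return ids;
    }
}
```

Like the original method, the sketch keeps duplicates: a node referenced both through a promise and as an explicit upstream ID appears twice in the list.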
Also used: Binding (org.flyte.api.v1.Binding), Node (org.flyte.api.v1.Node), Collections.emptyList (java.util.Collections.emptyList), START_NODE_ID (org.flyte.api.v1.Node.START_NODE_ID), DynamicWorkflowTask (org.flyte.api.v1.DynamicWorkflowTask), Set (java.util.Set), HashMap (java.util.HashMap), Deque (java.util.Deque), RunnableTask (org.flyte.api.v1.RunnableTask), Collections.singletonList (java.util.Collections.singletonList), ArrayList (java.util.ArrayList), HashSet (java.util.HashSet), Objects (java.util.Objects), Collectors.toList (java.util.stream.Collectors.toList), List (java.util.List), BindingData (org.flyte.api.v1.BindingData), Stream (java.util.stream.Stream), Map (java.util.Map), ArrayDeque (java.util.ArrayDeque), Comparator (java.util.Comparator)
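compile stores attempts as the task's configured retries plus one initial attempt. A local runner could consume that budget roughly as follows. This is a hypothetical sketch: RetryingRunner and runWithAttempts are illustrative names, not part of flytekit-java, and retrying on any RuntimeException is an assumption.

```java
import java.util.function.Supplier;

// Hypothetical helper, not the flytekit-java executor.
final class RetryingRunner {
    // Runs the task up to `attempts` times and rethrows the last failure
    // once the budget is exhausted.
    static <T> T runWithAttempts(int attempts, Supplier<T> task) {
        if (attempts < 1) {
            throw new IllegalArgumentException("attempts must be >= 1, got " + attempts);
        }
        RuntimeException last = null;
        for (int i = 1; i <= attempts; i++) {
            try {
                return task.get();
            } catch (RuntimeException e) {
                last = e; // remember the failure and retry while budget remains
            }
        }
        // Reached only after `attempts` failures, so `last` is non-null here.
        throw last;
    }
}
```

With retries() == 2, the budget is 3, so a task that fails twice and then succeeds still completes, while a task with retries() == 0 gets exactly one attempt.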

Aggregations

BranchNode (org.flyte.api.v1.BranchNode): 4
Node (org.flyte.api.v1.Node): 4
Test (org.junit.jupiter.api.Test): 4
IfElseBlock (org.flyte.api.v1.IfElseBlock): 3
ArrayList (java.util.ArrayList): 2
HashMap (java.util.HashMap): 2
Binding (org.flyte.api.v1.Binding): 2
BooleanExpression (org.flyte.api.v1.BooleanExpression): 2
ComparisonExpression (org.flyte.api.v1.ComparisonExpression): 2
IfBlock (org.flyte.api.v1.IfBlock): 2
TaskNode (org.flyte.api.v1.TaskNode): 2
WorkflowNode (org.flyte.api.v1.WorkflowNode): 2
WorkflowTemplate (org.flyte.api.v1.WorkflowTemplate): 2
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 2
Var (com.google.errorprone.annotations.Var): 1
Condition (flyteidl.core.Condition): 1
Workflow (flyteidl.core.Workflow): 1
ArrayDeque (java.util.ArrayDeque): 1
Collections.emptyList (java.util.Collections.emptyList): 1
Collections.singletonList (java.util.Collections.singletonList): 1