Search in sources:

Example 1 with TaskNode

Use of org.flyte.api.v1.TaskNode in project flytekit-java by flyteorg.

From class ProtoUtilTest, method createNode.

/**
 * Builds a single task {@link Node} whose derived names are all suffixed with {@code id}.
 *
 * <p>The node references task {@code "task-" + id} at version {@code "version-" + id} in the
 * test's {@code DOMAIN} / {@code PROJECT}. It binds one input, {@code "input-name-" + id}, to
 * the string scalar {@code "input-scalar-" + id}, and declares no upstream node ids.
 *
 * @param id the node id, also used as the suffix for every derived name
 * @return the constructed node
 */
private Node createNode(String id) {
    String taskName = "task-" + id;
    String version = "version-" + id;
    // Local names follow Java's lowerCamelCase convention.
    String inputName = "input-name-" + id;
    String inputScalar = "input-scalar-" + id;

    PartialTaskIdentifier reference =
        PartialTaskIdentifier.builder()
            .domain(DOMAIN)
            .project(PROJECT)
            .name(taskName)
            .version(version)
            .build();
    TaskNode taskNode = TaskNode.builder().referenceId(reference).build();

    List<Binding> inputs =
        singletonList(
            Binding.builder()
                .var_(inputName)
                .binding(
                    BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofStringValue(inputScalar))))
                .build());

    return Node.builder()
        .id(id)
        .taskNode(taskNode)
        .inputs(inputs)
        .upstreamNodeIds(emptyList())
        .build();
}
Also used : Binding(org.flyte.api.v1.Binding) TaskNode(org.flyte.api.v1.TaskNode) Matchers.containsString(org.hamcrest.Matchers.containsString)

Example 2 with TaskNode

Use of org.flyte.api.v1.TaskNode in project flytekit-java by flyteorg.

From class IdentifierRewriteTest, method shouldRewriteBranchNodes.

@Test
void shouldRewriteBranchNodes() {
    // Every if-block, on both the input and the expected side, shares the condition a == b.
    BooleanExpression condition =
        BooleanExpression.ofComparison(
            ComparisonExpression.builder()
                .operator(ComparisonExpression.Operator.EQ)
                .leftValue(Operand.ofVar("a"))
                .rightValue(Operand.ofVar("b"))
                .build());

    // Input side: the task reference carries only a name.
    PartialTaskIdentifier nameOnly = PartialTaskIdentifier.builder().name("task-name").build();
    Node inputNode =
        Node.builder()
            .id("node-1")
            .inputs(ImmutableList.of())
            .upstreamNodeIds(ImmutableList.of())
            .taskNode(TaskNode.builder().referenceId(nameOnly).build())
            .build();
    IfBlock inputIf = IfBlock.builder().condition(condition).thenNode(inputNode).build();
    BranchNode inputBranch =
        BranchNode.builder()
            .ifElse(
                IfElseBlock.builder()
                    .case_(inputIf)
                    .other(ImmutableList.of(inputIf))
                    .elseNode(inputNode)
                    .build())
            .build();

    // Expected side: same structure, with project, domain and version filled in.
    PartialTaskIdentifier fullyQualified =
        PartialTaskIdentifier.builder()
            .name("task-name")
            .project("rewritten-project")
            .domain("rewritten-domain")
            .version("rewritten-version")
            .build();
    Node expectedNode =
        Node.builder()
            .id("node-1")
            .inputs(ImmutableList.of())
            .upstreamNodeIds(ImmutableList.of())
            .taskNode(TaskNode.builder().referenceId(fullyQualified).build())
            .build();
    IfBlock expectedIf = IfBlock.builder().condition(condition).thenNode(expectedNode).build();
    BranchNode expectedBranch =
        BranchNode.builder()
            .ifElse(
                IfElseBlock.builder()
                    .case_(expectedIf)
                    .other(ImmutableList.of(expectedIf))
                    .elseNode(expectedNode)
                    .build())
            .build();

    // The rewrite must reach the task nodes nested under case_, other and elseNode alike.
    assertThat(rewriter.visitor().visitBranchNode(inputBranch), equalTo(expectedBranch));
}
Also used : ComparisonExpression(org.flyte.api.v1.ComparisonExpression) BooleanExpression(org.flyte.api.v1.BooleanExpression) TaskNode(org.flyte.api.v1.TaskNode) BranchNode(org.flyte.api.v1.BranchNode) TaskNode(org.flyte.api.v1.TaskNode) BranchNode(org.flyte.api.v1.BranchNode) WorkflowNode(org.flyte.api.v1.WorkflowNode) Node(org.flyte.api.v1.Node) PartialTaskIdentifier(org.flyte.api.v1.PartialTaskIdentifier) IfBlock(org.flyte.api.v1.IfBlock) Test(org.junit.jupiter.api.Test)

Example 3 with TaskNode

Use of org.flyte.api.v1.TaskNode in project flytekit-java by flyteorg.

From class SdkWorkflowBuilderTest, method testConditionalWorkflowIdl.

@Test
void testConditionalWorkflowIdl() {
    SdkWorkflowBuilder builder = new SdkWorkflowBuilder();
    new ConditionalWorkflow().expand(builder);

    // Both the branch node and its case node read the workflow-level "in" input.
    OutputReference workflowIn =
        OutputReference.builder().var("in").nodeId(Node.START_NODE_ID).build();

    // Case body: MultiplicationTask with a = in and b = 2.
    Binding aBinding =
        Binding.builder().var_("a").binding(BindingData.ofOutputReference(workflowIn)).build();
    Binding bBinding =
        Binding.builder()
            .var_("b")
            .binding(BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(2L))))
            .build();
    PartialTaskIdentifier multiplicationTask =
        PartialTaskIdentifier.builder()
            .name("org.flyte.flytekit.SdkWorkflowBuilderTest$MultiplicationTask")
            .build();
    Node caseNode =
        Node.builder()
            .id("neq")
            .taskNode(TaskNode.builder().referenceId(multiplicationTask).build())
            .inputs(asList(aBinding, bBinding))
            .upstreamNodeIds(emptyList())
            .build();

    // Single case guarded by $0 != 2, plus the error recorded (for node "square") when no
    // case matches.
    ComparisonExpression notTwo =
        ComparisonExpression.builder()
            .operator(ComparisonExpression.Operator.NEQ)
            .leftValue(Operand.ofVar("$0"))
            .rightValue(Operand.ofPrimitive(Primitive.ofIntegerValue(2L)))
            .build();
    IfBlock caseBlock =
        IfBlock.builder()
            .condition(BooleanExpression.ofComparison(notTwo))
            .thenNode(caseNode)
            .build();
    NodeError noMatch =
        NodeError.builder().message("No cases matched").failedNodeId("square").build();
    IfElseBlock ifElse =
        IfElseBlock.builder().case_(caseBlock).other(emptyList()).error(noMatch).build();

    // The branch node binds "in" to its own "$0" input.
    Binding branchInput =
        Binding.builder().var_("$0").binding(BindingData.ofOutputReference(workflowIn)).build();
    Node expectedNode =
        Node.builder()
            .id("square")
            .branchNode(BranchNode.builder().ifElse(ifElse).build())
            .inputs(singletonList(branchInput))
            .upstreamNodeIds(emptyList())
            .build();

    WorkflowTemplate expected =
        WorkflowTemplate.builder()
            .metadata(WorkflowMetadata.builder().build())
            .interface_(expectedInterface())
            .outputs(expectedOutputs())
            .nodes(singletonList(expectedNode))
            .build();

    assertEquals(expected, builder.toIdlTemplate());
}
Also used : IfElseBlock(org.flyte.api.v1.IfElseBlock) WorkflowTemplate(org.flyte.api.v1.WorkflowTemplate) TaskNode(org.flyte.api.v1.TaskNode) BranchNode(org.flyte.api.v1.BranchNode) Node(org.flyte.api.v1.Node) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 4 with TaskNode

Use of org.flyte.api.v1.TaskNode in project flytekit-java by flyteorg.

From class SdkTestingExecutor, method execute.

/**
 * Expands the configured workflow, verifies that each top-level task node has a registered
 * implementation, and runs it on the local engine.
 *
 * @return the output literals together with the output literal types declared by the
 *     workflow interface
 */
public Result execute() {
    TestingSdkWorkflowBuilder builder =
        new TestingSdkWorkflowBuilder(fixedInputMap(), fixedInputTypeMap());
    workflow().expand(builder);
    WorkflowTemplate template = builder.toIdlTemplate();

    // Only top-level task nodes are checked here; a missing task fails before execution starts.
    for (Node node : template.nodes()) {
        TaskNode task = node.taskNode();
        if (task == null) {
            continue;
        }
        String name = task.referenceId().name();
        checkArgument(
            fixedTaskMap().containsKey(name),
            "Can't execute remote task [%s], "
                + "use SdkTestingExecutor#withTaskOutput or SdkTestingExecutor#withTask",
            name);
    }

    Map<String, Literal> outputs =
        LocalEngine.compileAndExecute(
            template, unmodifiableMap(fixedTaskMap()), emptyMap(), fixedInputMap());

    Map<String, LiteralType> outputTypes =
        template.interface_().outputs().entrySet().stream()
            .collect(toMap(Map.Entry::getKey, entry -> entry.getValue().literalType()));

    return Result.create(outputs, outputTypes);
}
Also used : TaskNode(org.flyte.api.v1.TaskNode) SdkRunnableTask(org.flyte.flytekit.SdkRunnableTask) HashMap(java.util.HashMap) Function(java.util.function.Function) ArrayList(java.util.ArrayList) Collectors.toMap(java.util.stream.Collectors.toMap) WorkflowTemplate(org.flyte.api.v1.WorkflowTemplate) SdkType(org.flyte.flytekit.SdkType) Duration(java.time.Duration) Map(java.util.Map) Preconditions.checkArgument(org.flyte.flytekit.testing.Preconditions.checkArgument) Collections.emptyMap(java.util.Collections.emptyMap) Node(org.flyte.api.v1.Node) LiteralType(org.flyte.api.v1.LiteralType) Literal(org.flyte.api.v1.Literal) ServiceLoader(java.util.ServiceLoader) Variable(org.flyte.api.v1.Variable) Instant(java.time.Instant) SdkWorkflow(org.flyte.flytekit.SdkWorkflow) Var(com.google.errorprone.annotations.Var) List(java.util.List) SdkRemoteTask(org.flyte.flytekit.SdkRemoteTask) AutoValue(com.google.auto.value.AutoValue) Collections.unmodifiableMap(java.util.Collections.unmodifiableMap) LocalEngine(org.flyte.localengine.LocalEngine) TaskNode(org.flyte.api.v1.TaskNode) WorkflowTemplate(org.flyte.api.v1.WorkflowTemplate) TaskNode(org.flyte.api.v1.TaskNode) Node(org.flyte.api.v1.Node) Literal(org.flyte.api.v1.Literal) LiteralType(org.flyte.api.v1.LiteralType) HashMap(java.util.HashMap) Collectors.toMap(java.util.stream.Collectors.toMap) Map(java.util.Map) Collections.emptyMap(java.util.Collections.emptyMap) Collections.unmodifiableMap(java.util.Collections.unmodifiableMap)

Example 5 with TaskNode

Use of org.flyte.api.v1.TaskNode in project flytekit-java by flyteorg.

From class FlyteAdminClientTest, method shouldPropagateCreateWorkflowToStub.

@Test
public void shouldPropagateCreateWorkflowToStub() {
    String nodeId = "node";
    WorkflowIdentifier identifier =
        WorkflowIdentifier.builder()
            .project(PROJECT)
            .domain(DOMAIN)
            .name(WF_NAME)
            .version(WF_VERSION)
            .build();

    // One task node with a single string-scalar input binding.
    PartialTaskIdentifier taskRef =
        PartialTaskIdentifier.builder()
            .project(PROJECT)
            .domain(DOMAIN)
            .name(TASK_NAME)
            .version(TASK_VERSION)
            .build();
    Binding input =
        Binding.builder()
            .var_(VAR_NAME)
            .binding(BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofStringValue(SCALAR))))
            .build();
    Node node =
        Node.builder()
            .id(nodeId)
            .taskNode(TaskNode.builder().referenceId(taskRef).build())
            .inputs(ImmutableList.of(input))
            .upstreamNodeIds(emptyList())
            .build();

    // The workflow interface declares no inputs and no outputs.
    TypedInterface interface_ =
        TypedInterface.builder().inputs(ImmutableMap.of()).outputs(ImmutableMap.of()).build();
    WorkflowTemplate template =
        WorkflowTemplate.builder()
            .metadata(WorkflowMetadata.builder().build())
            .interface_(interface_)
            .nodes(ImmutableList.of(node))
            .outputs(ImmutableList.of())
            .build();

    client.createWorkflow(identifier, template, ImmutableMap.of());

    // The stub should receive a create request carrying the workflow id and the equivalent spec.
    WorkflowOuterClass.WorkflowCreateRequest expectedRequest =
        WorkflowOuterClass.WorkflowCreateRequest.newBuilder()
            .setId(newIdentifier(ResourceType.WORKFLOW, WF_NAME, WF_VERSION))
            .setSpec(newWorkflowSpec(nodeId))
            .build();
    assertThat(stubService.createWorkflowRequest, equalTo(expectedRequest));
}
Also used : TypedInterface(org.flyte.api.v1.TypedInterface) PartialWorkflowIdentifier(org.flyte.api.v1.PartialWorkflowIdentifier) WorkflowIdentifier(org.flyte.api.v1.WorkflowIdentifier) TaskNode(org.flyte.api.v1.TaskNode) WorkflowTemplate(org.flyte.api.v1.WorkflowTemplate) TaskNode(org.flyte.api.v1.TaskNode) Node(org.flyte.api.v1.Node) Test(org.junit.Test)

Aggregations

TaskNode (org.flyte.api.v1.TaskNode) 8, Node (org.flyte.api.v1.Node) 7, WorkflowTemplate (org.flyte.api.v1.WorkflowTemplate) 4, Test (org.junit.jupiter.api.Test) 4, BranchNode (org.flyte.api.v1.BranchNode) 3, List (java.util.List) 2, Map (java.util.Map) 2, Binding (org.flyte.api.v1.Binding) 2, PartialTaskIdentifier (org.flyte.api.v1.PartialTaskIdentifier) 2, Variable (org.flyte.api.v1.Variable) 2, ParameterizedTest (org.junit.jupiter.params.ParameterizedTest) 2, AutoValue (com.google.auto.value.AutoValue) 1, Var (com.google.errorprone.annotations.Var) 1, Duration (java.time.Duration) 1, Instant (java.time.Instant) 1, ArrayList (java.util.ArrayList) 1, Collections.emptyMap (java.util.Collections.emptyMap) 1, Collections.unmodifiableMap (java.util.Collections.unmodifiableMap) 1, HashMap (java.util.HashMap) 1, ServiceLoader (java.util.ServiceLoader) 1