use of org.flyte.api.v1.TaskNode in project flytekit-java by flyteorg.
the class ProtoUtilTest method createNode.
/**
 * Builds a task-node fixture whose task name, task version, input name and input value are all
 * derived from {@code id}.
 *
 * @param id node id; also appended to every generated name and value
 * @return a node carrying one string-scalar input binding and an empty upstream-node list
 */
private Node createNode(String id) {
  PartialTaskIdentifier taskId =
      PartialTaskIdentifier.builder()
          .domain(DOMAIN)
          .project(PROJECT)
          .name("task-" + id)
          .version("version-" + id)
          .build();

  BindingData scalarValue =
      BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofStringValue("input-scalar-" + id)));
  Binding inputBinding = Binding.builder().var_("input-name-" + id).binding(scalarValue).build();

  return Node.builder()
      .id(id)
      .taskNode(TaskNode.builder().referenceId(taskId).build())
      .inputs(singletonList(inputBinding))
      .upstreamNodeIds(emptyList())
      .build();
}
use of org.flyte.api.v1.TaskNode in project flytekit-java by flyteorg.
the class IdentifierRewriteTest method shouldRewriteBranchNodes.
/**
 * Checks that visiting a branch node yields a branch node in which the task reference of the
 * primary case, of the additional cases and of the else node has been completed with domain,
 * version and project.
 *
 * <p>NOTE(review): the expected "rewritten-*" values presumably come from how {@code rewriter} is
 * configured elsewhere in this test class.
 */
@Test
void shouldRewriteBranchNodes() {
  // The same condition is shared by the input and the expected blocks.
  BooleanExpression condition =
      BooleanExpression.ofComparison(
          ComparisonExpression.builder()
              .operator(ComparisonExpression.Operator.EQ)
              .leftValue(Operand.ofVar("a"))
              .rightValue(Operand.ofVar("b"))
              .build());

  // Input: the task is referenced by name only.
  TaskNode taskBefore =
      TaskNode.builder()
          .referenceId(PartialTaskIdentifier.builder().name("task-name").build())
          .build();
  Node nodeBefore =
      Node.builder()
          .id("node-1")
          .inputs(ImmutableList.of())
          .upstreamNodeIds(ImmutableList.of())
          .taskNode(taskBefore)
          .build();
  IfBlock ifBefore = IfBlock.builder().condition(condition).thenNode(nodeBefore).build();
  IfElseBlock ifElseBefore =
      IfElseBlock.builder()
          .case_(ifBefore)
          .other(ImmutableList.of(ifBefore))
          .elseNode(nodeBefore)
          .build();
  BranchNode branchBefore = BranchNode.builder().ifElse(ifElseBefore).build();

  // Expected: identical structure, but with a fully populated task reference.
  TaskNode taskAfter =
      TaskNode.builder()
          .referenceId(
              PartialTaskIdentifier.builder()
                  .name("task-name")
                  .domain("rewritten-domain")
                  .version("rewritten-version")
                  .project("rewritten-project")
                  .build())
          .build();
  Node nodeAfter =
      Node.builder()
          .id("node-1")
          .inputs(ImmutableList.of())
          .upstreamNodeIds(ImmutableList.of())
          .taskNode(taskAfter)
          .build();
  IfBlock ifAfter = IfBlock.builder().condition(condition).thenNode(nodeAfter).build();
  IfElseBlock ifElseAfter =
      IfElseBlock.builder()
          .case_(ifAfter)
          .other(ImmutableList.of(ifAfter))
          .elseNode(nodeAfter)
          .build();
  BranchNode branchAfter = BranchNode.builder().ifElse(ifElseAfter).build();

  assertThat(rewriter.visitor().visitBranchNode(branchBefore), equalTo(branchAfter));
}
use of org.flyte.api.v1.TaskNode in project flytekit-java by flyteorg.
the class SdkWorkflowBuilderTest method testConditionalWorkflowIdl.
/**
 * Expands {@code ConditionalWorkflow} into a fresh builder and compares the resulting IDL template
 * with a hand-built one: a single branch node "square" whose only case runs the multiplication
 * task when {@code $0 != 2}, and which carries a "No cases matched" error.
 */
@Test
void testConditionalWorkflowIdl() {
  SdkWorkflowBuilder builder = new SdkWorkflowBuilder();
  new ConditionalWorkflow().expand(builder);

  // Both the branch node and the task inside it read the workflow input "in".
  BindingData workflowInput =
      BindingData.ofOutputReference(
          OutputReference.builder().var("in").nodeId(Node.START_NODE_ID).build());
  Primitive two = Primitive.ofIntegerValue(2L);

  TaskNode multiplicationTask =
      TaskNode.builder()
          .referenceId(
              PartialTaskIdentifier.builder()
                  .name("org.flyte.flytekit.SdkWorkflowBuilderTest$MultiplicationTask")
                  .build())
          .build();
  Binding bindingA = Binding.builder().var_("a").binding(workflowInput).build();
  Binding bindingB =
      Binding.builder().var_("b").binding(BindingData.ofScalar(Scalar.ofPrimitive(two))).build();
  Node caseNode =
      Node.builder()
          .id("neq")
          .taskNode(multiplicationTask)
          .inputs(asList(bindingA, bindingB))
          .upstreamNodeIds(emptyList())
          .build();

  ComparisonExpression notEqualToTwo =
      ComparisonExpression.builder()
          .leftValue(Operand.ofVar("$0"))
          .rightValue(Operand.ofPrimitive(two))
          .operator(ComparisonExpression.Operator.NEQ)
          .build();
  IfBlock onlyCase =
      IfBlock.builder()
          .condition(BooleanExpression.ofComparison(notEqualToTwo))
          .thenNode(caseNode)
          .build();
  NodeError noMatch =
      NodeError.builder().message("No cases matched").failedNodeId("square").build();
  IfElseBlock ifElse =
      IfElseBlock.builder().case_(onlyCase).error(noMatch).other(emptyList()).build();

  Binding branchInput = Binding.builder().var_("$0").binding(workflowInput).build();
  Node expectedNode =
      Node.builder()
          .id("square")
          .branchNode(BranchNode.builder().ifElse(ifElse).build())
          .inputs(singletonList(branchInput))
          .upstreamNodeIds(emptyList())
          .build();

  WorkflowTemplate expected =
      WorkflowTemplate.builder()
          .metadata(WorkflowMetadata.builder().build())
          .interface_(expectedInterface())
          .outputs(expectedOutputs())
          .nodes(singletonList(expectedNode))
          .build();

  assertEquals(expected, builder.toIdlTemplate());
}
use of org.flyte.api.v1.TaskNode in project flytekit-java by flyteorg.
the class SdkTestingExecutor method execute.
/**
 * Expands the workflow into a testing builder, checks that every task node returned by the
 * template's {@code nodes()} has an entry in {@code fixedTaskMap()}, then compiles and runs the
 * template through {@code LocalEngine}.
 *
 * @return the output literals of the run paired with the literal types declared by the workflow
 *     interface outputs
 */
public Result execute() {
  TestingSdkWorkflowBuilder testingBuilder =
      new TestingSdkWorkflowBuilder(fixedInputMap(), fixedInputTypeMap());
  workflow().expand(testingBuilder);
  WorkflowTemplate template = testingBuilder.toIdlTemplate();

  for (Node candidate : template.nodes()) {
    TaskNode task = candidate.taskNode();
    if (task == null) {
      // Nodes without a task reference need no registered test double.
      continue;
    }

    String name = task.referenceId().name();
    checkArgument(
        fixedTaskMap().containsKey(name),
        "Can't execute remote task [%s], "
            + "use SdkTestingExecutor#withTaskOutput or SdkTestingExecutor#withTask",
        name);
  }

  Map<String, Literal> outputs =
      LocalEngine.compileAndExecute(
          template, unmodifiableMap(fixedTaskMap()), emptyMap(), fixedInputMap());

  Map<String, LiteralType> outputTypes =
      template.interface_().outputs().entrySet().stream()
          .collect(toMap(entry -> entry.getKey(), entry -> entry.getValue().literalType()));

  return Result.create(outputs, outputTypes);
}
use of org.flyte.api.v1.TaskNode in project flytekit-java by flyteorg.
the class FlyteAdminClientTest method shouldPropagateCreateWorkflowToStub.
@Test
public void shouldPropagateCreateWorkflowToStub() {
String nodeId = "node";
WorkflowIdentifier identifier = WorkflowIdentifier.builder().domain(DOMAIN).project(PROJECT).name(WF_NAME).version(WF_VERSION).build();
TaskNode taskNode = TaskNode.builder().referenceId(PartialTaskIdentifier.builder().domain(DOMAIN).project(PROJECT).name(TASK_NAME).version(TASK_VERSION).build()).build();
Node node = Node.builder().id(nodeId).taskNode(taskNode).inputs(ImmutableList.of(Binding.builder().var_(VAR_NAME).binding(BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofStringValue(SCALAR)))).build())).upstreamNodeIds(emptyList()).build();
TypedInterface interface_ = TypedInterface.builder().inputs(ImmutableMap.of()).outputs(ImmutableMap.of()).build();
WorkflowTemplate template = WorkflowTemplate.builder().nodes(ImmutableList.of(node)).metadata(WorkflowMetadata.builder().build()).interface_(interface_).outputs(ImmutableList.of()).build();
client.createWorkflow(identifier, template, ImmutableMap.of());
assertThat(stubService.createWorkflowRequest, equalTo(WorkflowOuterClass.WorkflowCreateRequest.newBuilder().setId(newIdentifier(ResourceType.WORKFLOW, WF_NAME, WF_VERSION)).setSpec(newWorkflowSpec(nodeId)).build()));
}
Aggregations