Usage of `org.flyte.api.v1.Node` in project flytekit-java by flyteorg.
Class `ProtoUtilTest`, method `shouldSerializeWorkflowTemplate`:
@Test
void shouldSerializeWorkflowTemplate() {
  // Node "a" depends on node "b"; node "b" additionally carries custom metadata
  // (display name, timeout and retry strategy) that must survive serialization.
  Node nodeA = createNode("a").toBuilder().upstreamNodeIds(singletonList("b")).build();
  Node nodeB =
      createNode("b").toBuilder()
          .metadata(
              NodeMetadata.builder()
                  .name("fancy-b")
                  .timeout(Duration.ofMinutes(15))
                  .retries(RetryStrategy.builder().retries(3).build())
                  .build())
          .build();

  WorkflowMetadata metadata = WorkflowMetadata.builder().build();
  TypedInterface interface_ =
      TypedInterface.builder().inputs(emptyMap()).outputs(emptyMap()).build();
  WorkflowTemplate template =
      WorkflowTemplate.builder()
          .nodes(asList(nodeA, nodeB))
          .metadata(metadata)
          .interface_(interface_)
          .outputs(emptyList())
          .build();

  Workflow.Node expectedNode1 =
      Workflow.Node.newBuilder()
          .setId("a")
          .addUpstreamNodeIds("b")
          .setTaskNode(
              Workflow.TaskNode.newBuilder()
                  .setReferenceId(
                      IdentifierOuterClass.Identifier.newBuilder()
                          .setResourceType(TASK)
                          .setDomain(DOMAIN)
                          .setProject(PROJECT)
                          .setName("task-a")
                          .setVersion("version-a")
                          .build())
                  .build())
          .addInputs(
              Literals.Binding.newBuilder()
                  .setVar("input-name-a")
                  .setBinding(
                      Literals.BindingData.newBuilder()
                          .setScalar(
                              Literals.Scalar.newBuilder()
                                  .setPrimitive(
                                      Literals.Primitive.newBuilder()
                                          .setStringValue("input-scalar-a")
                                          .build())
                                  .build())
                          .build())
                  .build())
          .build();

  Workflow.Node expectedNode2 =
      Workflow.Node.newBuilder()
          .setId("b")
          .setMetadata(
              Workflow.NodeMetadata.newBuilder()
                  .setName("fancy-b")
                  // The 15-minute java.time timeout is expected as 15 * 60 seconds in the proto.
                  .setTimeout(com.google.protobuf.Duration.newBuilder().setSeconds(15 * 60).build())
                  .setRetries(Literals.RetryStrategy.newBuilder().setRetries(3).build())
                  .build())
          .setTaskNode(
              Workflow.TaskNode.newBuilder()
                  .setReferenceId(
                      IdentifierOuterClass.Identifier.newBuilder()
                          .setResourceType(TASK)
                          .setDomain(DOMAIN)
                          .setProject(PROJECT)
                          .setName("task-b")
                          .setVersion("version-b")
                          .build())
                  .build())
          .addInputs(
              Literals.Binding.newBuilder()
                  .setVar("input-name-b")
                  .setBinding(
                      Literals.BindingData.newBuilder()
                          .setScalar(
                              Literals.Scalar.newBuilder()
                                  .setPrimitive(
                                      Literals.Primitive.newBuilder()
                                          .setStringValue("input-scalar-b")
                                          .build())
                                  .build())
                          .build())
                  .build())
          .build();

  // Empty API-side maps are expected to serialize to empty (but present) VariableMaps.
  Workflow.WorkflowTemplate expectedTemplate =
      Workflow.WorkflowTemplate.newBuilder()
          .setMetadata(Workflow.WorkflowMetadata.newBuilder().build())
          .setInterface(
              Interface.TypedInterface.newBuilder()
                  .setOutputs(Interface.VariableMap.newBuilder().build())
                  .setInputs(Interface.VariableMap.newBuilder().build())
                  .build())
          .addNodes(expectedNode1)
          .addNodes(expectedNode2)
          .build();

  Workflow.WorkflowTemplate serializedTemplate = ProtoUtil.serialize(template);

  assertThat(serializedTemplate, equalTo(expectedTemplate));
}
Usage of `org.flyte.api.v1.Node` in project flytekit-java by flyteorg.
Class `ProtoUtilTest`, method `createNode`:
/**
 * Builds a single-task test node whose task name, version and sole string input are all
 * derived from {@code id}, so expected serialized values can be predicted from the id alone.
 */
private Node createNode(String id) {
  PartialTaskIdentifier reference =
      PartialTaskIdentifier.builder()
          .domain(DOMAIN)
          .project(PROJECT)
          .name("task-" + id)
          .version("version-" + id)
          .build();

  Binding inputBinding =
      Binding.builder()
          .var_("input-name-" + id)
          .binding(
              BindingData.ofScalar(
                  Scalar.ofPrimitive(Primitive.ofStringValue("input-scalar-" + id))))
          .build();

  return Node.builder()
      .id(id)
      .taskNode(TaskNode.builder().referenceId(reference).build())
      .inputs(singletonList(inputBinding))
      .upstreamNodeIds(emptyList())
      .build();
}
Usage of `org.flyte.api.v1.Node` in project flytekit-java by flyteorg.
Class `ProtoUtilTest`, method `shouldSerializeOutputReference`:
@Test
void shouldSerializeOutputReference() {
  // An API OutputReference should map field-for-field onto its proto counterpart.
  Types.OutputReference serialized =
      ProtoUtil.serialize(OutputReference.builder().nodeId("node-id").var("var").build());

  assertEquals(
      Types.OutputReference.newBuilder().setNodeId("node-id").setVar("var").build(), serialized);
}
Usage of `org.flyte.api.v1.Node` in project flytekit-java by flyteorg.
Class `ProjectClosure`, method `checkCycles`:
/**
 * Depth-first search for a cycle in the sub-workflow reference graph reachable from {@code
 * workflowId}.
 *
 * <p>{@code beingVisited} holds the workflows on the current DFS path; reaching one of them
 * again is a back edge and therefore a cycle. {@code visited} holds workflows whose subtree has
 * been fully explored without finding a cycle, so they are not explored again.
 *
 * <p>Nodes are first expanded through {@code flatBranch} (presumably so that sub-workflow
 * references nested inside branch nodes are also considered — confirm against its
 * implementation).
 *
 * @param workflowId workflow to start the search from; must be a key of {@code allWorkflows}
 * @param allWorkflows all known workflow templates, keyed by identifier
 * @param beingVisited workflows on the current DFS path; mutated by this method
 * @param visited workflows already fully explored; mutated by this method
 * @return {@code true} if a cycle is reachable from {@code workflowId}. When {@code true} is
 *     returned, {@code beingVisited} is left holding the path that led to the cycle.
 * @throws NullPointerException if {@code workflowId}, or any sub-workflow it references, is not
 *     present in {@code allWorkflows}
 */
static boolean checkCycles(
    WorkflowIdentifier workflowId,
    Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows,
    Set<WorkflowIdentifier> beingVisited,
    Set<WorkflowIdentifier> visited) {
  beingVisited.add(workflowId);

  // Same exception type as before (NPE), but with a message naming the missing workflow
  // instead of an anonymous failure on workflow.nodes().
  WorkflowTemplate workflow =
      Objects.requireNonNull(
          allWorkflows.get(workflowId),
          () -> "Workflow not found while checking for cycles: " + workflowId);

  List<Node> nodes =
      workflow.nodes().stream().flatMap(ProjectClosure::flatBranch).collect(toUnmodifiableList());

  for (Node node : nodes) {
    if (isSubWorkflowNode(node)) {
      PartialWorkflowIdentifier partialSubWorkflowId =
          Objects.requireNonNull(node.workflowNode()).reference().subWorkflowRef();
      WorkflowIdentifier subWorkflowId =
          WorkflowIdentifier.builder()
              .project(partialSubWorkflowId.project())
              .name(partialSubWorkflowId.name())
              .domain(partialSubWorkflowId.domain())
              .version(partialSubWorkflowId.version())
              .build();

      // An edge to a workflow still on the current DFS path is a back edge, i.e. a cycle.
      boolean backEdge = beingVisited.contains(subWorkflowId);
      if (backEdge
          || (!visited.contains(subWorkflowId)
              && checkCycles(subWorkflowId, allWorkflows, beingVisited, visited))) {
        return true;
      }
    }
  }

  // Subtree fully explored without a cycle: pop it off the path and memoize it.
  beingVisited.remove(workflowId);
  visited.add(workflowId);
  return false;
}
Usage of `org.flyte.api.v1.Node` in project flytekit-java by flyteorg.
Class `IdentifierRewriteTest`, method `shouldRewriteBranchNodes`:
@Test
void shouldRewriteBranchNodes() {
  // Shared branch condition: a == b.
  BooleanExpression condition =
      BooleanExpression.ofComparison(
          ComparisonExpression.builder()
              .operator(ComparisonExpression.Operator.EQ)
              .leftValue(Operand.ofVar("a"))
              .rightValue(Operand.ofVar("b"))
              .build());

  // Builds the same single-task node around whichever task reference is supplied.
  java.util.function.Function<PartialTaskIdentifier, Node> nodeFor =
      reference ->
          Node.builder()
              .id("node-1")
              .inputs(ImmutableList.of())
              .upstreamNodeIds(ImmutableList.of())
              .taskNode(TaskNode.builder().referenceId(reference).build())
              .build();

  // Wraps a node into a branch whose case, "other" and else arms all lead to that node.
  java.util.function.Function<Node, BranchNode> branchAround =
      node -> {
        IfBlock ifBlock = IfBlock.builder().condition(condition).thenNode(node).build();
        return BranchNode.builder()
            .ifElse(
                IfElseBlock.builder()
                    .case_(ifBlock)
                    .other(ImmutableList.of(ifBlock))
                    .elseNode(node)
                    .build())
            .build();
      };

  // Only the task name is known up front; the rewriter must fill in domain, version, project.
  Node inputNode = nodeFor.apply(PartialTaskIdentifier.builder().name("task-name").build());
  Node expectedNode =
      nodeFor.apply(
          PartialTaskIdentifier.builder()
              .name("task-name")
              .domain("rewritten-domain")
              .version("rewritten-version")
              .project("rewritten-project")
              .build());

  BranchNode inputBranch = branchAround.apply(inputNode);
  BranchNode expectedBranch = branchAround.apply(expectedNode);

  assertThat(rewriter.visitor().visitBranchNode(inputBranch), equalTo(expectedBranch));
}
Aggregations