Use of org.flyte.api.v1.TaskTemplate in the flytekit-java project by flyteorg.
Class ExecuteDynamicWorkflow, method execute.
/**
 * Runs the dynamic workflow task described by the task template at {@code taskTemplatePath}.
 *
 * <p>Loads registered workflows, runnable tasks and dynamic workflow tasks from the user package,
 * runs the task against the literals read from {@code inputs}, rewrites the produced job spec, and
 * writes either plain outputs (when the spec has no nodes) or the futures spec. Any failure inside
 * the main body is logged and written as an error document instead of being rethrown.
 */
private void execute() {
  Config config = Config.load();
  ExecutionConfig executionConfig = ExecutionConfig.load();
  Collection<ClassLoader> moduleClassLoaders =
      ClassLoaders.forModuleDir(config.moduleDir()).values();
  Map<String, FileSystem> fileSystems = FileSystemLoader.loadFileSystems(moduleClassLoaders);

  // The writer is created before the try block so that the catch clauses can report failures.
  FileSystem outputFileSystem = FileSystemLoader.getFileSystem(fileSystems, outputPrefix);
  ProtoWriter writer = new ProtoWriter(outputPrefix, outputFileSystem);

  try {
    FileSystem inputFileSystem = FileSystemLoader.getFileSystem(fileSystems, inputs);
    ProtoReader reader = new ProtoReader(inputFileSystem);
    TaskTemplate currentTemplate = reader.getTaskTemplate(taskTemplatePath);
    ClassLoader packageClassLoader = PackageLoader.load(fileSystems, currentTemplate);
    Map<String, String> env = getEnv();

    // Registrar discovery happens under the package class loader so that it can see the user's
    // classes (presumably via the thread context class loader -- see ClassLoaders.withClassLoader).
    Map<WorkflowIdentifier, WorkflowTemplate> workflowTemplates =
        ClassLoaders.withClassLoader(
            packageClassLoader, () -> Registrars.loadAll(WorkflowTemplateRegistrar.class, env));
    Map<TaskIdentifier, RunnableTask> runnableTasks =
        ClassLoaders.withClassLoader(
            packageClassLoader, () -> Registrars.loadAll(RunnableTaskRegistrar.class, env));
    Map<TaskIdentifier, DynamicWorkflowTask> dynamicWorkflowTasks =
        ClassLoaders.withClassLoader(
            packageClassLoader,
            () -> Registrars.loadAll(DynamicWorkflowTaskRegistrar.class, env));

    // Keep only the jflyte portion of the current template's "custom" field by round-tripping it
    // through JFlyteCustom. Every task already has staged jars, so this jflyte custom struct is
    // merged into the templates of all other tasks.
    Struct jflyteCustom =
        JFlyteCustom.deserializeFromStruct(currentTemplate.custom()).serializeToStruct();
    Map<TaskIdentifier, TaskTemplate> taskTemplates =
        mapValues(
            ProjectClosure.createTaskTemplates(
                executionConfig, runnableTasks, dynamicWorkflowTasks),
            template ->
                template
                    .toBuilder()
                    .custom(ProjectClosure.merge(template.custom(), jflyteCustom))
                    .build());

    // Switch to the package class loader before running user code; otherwise ServiceLoaders and
    // similar mechanisms (for instance, FileSystemRegister in Apache Beam) would not work.
    DynamicJobSpec jobSpec =
        withClassLoader(
            packageClassLoader,
            () -> {
              Map<String, Literal> inputLiterals = reader.getInput(inputs);
              DynamicWorkflowTask dynamicTask = getDynamicWorkflowTask(this.task);
              return dynamicTask.run(inputLiterals);
            });

    DynamicJobSpec rewrittenSpec =
        rewrite(executionConfig, jobSpec, taskTemplates, workflowTemplates);
    if (!rewrittenSpec.nodes().isEmpty()) {
      // The spec still contains nodes to execute, so it is written back as futures.
      writer.writeFutures(rewrittenSpec);
    } else {
      // No nodes: the spec's output bindings are converted to literals and written as outputs.
      writer.writeOutputs(getLiteralMap(rewrittenSpec.outputs()));
    }
  } catch (ContainerError e) {
    LOG.error("failed to run dynamic workflow", e);
    writer.writeError(ProtoUtil.serializeContainerError(e));
  } catch (Throwable e) {
    LOG.error("failed to run dynamic workflow", e);
    writer.writeError(ProtoUtil.serializeThrowable(e));
  }
}
Use of org.flyte.api.v1.TaskTemplate in the flytekit-java project by flyteorg.
Class ProjectClosure, method createTaskTemplates.
/**
 * Builds a task template for every runnable task and every dynamic workflow task.
 *
 * <p>Runnable tasks are converted with {@code createTaskTemplateForRunnableTask}, dynamic workflow
 * tasks with {@code createTaskTemplateForDynamicWorkflow}, both using {@code config.image()}.
 * Dynamic workflow tasks are inserted second, so if an identifier appears in both maps its dynamic
 * workflow template replaces the runnable one.
 *
 * @param config execution config supplying the container image
 * @param runnableTasks runnable tasks keyed by identifier
 * @param dynamicWorkflowTasks dynamic workflow tasks keyed by identifier
 * @return a new mutable map from task identifier to its template
 */
static Map<TaskIdentifier, TaskTemplate> createTaskTemplates(
    ExecutionConfig config,
    Map<TaskIdentifier, RunnableTask> runnableTasks,
    Map<TaskIdentifier, DynamicWorkflowTask> dynamicWorkflowTasks) {
  Map<TaskIdentifier, TaskTemplate> templatesById = new HashMap<>();

  for (Map.Entry<TaskIdentifier, RunnableTask> entry : runnableTasks.entrySet()) {
    TaskTemplate runnableTemplate =
        createTaskTemplateForRunnableTask(entry.getValue(), config.image());
    templatesById.put(entry.getKey(), runnableTemplate);
  }

  for (Map.Entry<TaskIdentifier, DynamicWorkflowTask> entry : dynamicWorkflowTasks.entrySet()) {
    TaskTemplate dynamicTemplate =
        createTaskTemplateForDynamicWorkflow(entry.getValue(), config.image());
    templatesById.put(entry.getKey(), dynamicTemplate);
  }

  return templatesById;
}
Use of org.flyte.api.v1.TaskTemplate in the flytekit-java project by flyteorg.
Class ProtoUtil, method serialize.
/**
 * Converts a {@link TaskTemplate} into its {@code Tasks.TaskTemplate} protobuf message.
 *
 * <p>The task metadata carries a FLYTE_SDK runtime description ({@code RUNTIME_FLAVOR},
 * {@code RUNTIME_VERSION}) and the serialized retry strategy. Container, interface, type and
 * custom struct are copied from the template.
 *
 * @param taskTemplate template to serialize
 * @return the protobuf representation
 * @throws NullPointerException if the template has no container, since only container based task
 *     templates are supported
 */
static Tasks.TaskTemplate serialize(TaskTemplate taskTemplate) {
  Tasks.RuntimeMetadata.Builder runtimeBuilder = Tasks.RuntimeMetadata.newBuilder();
  runtimeBuilder.setType(Tasks.RuntimeMetadata.RuntimeType.FLYTE_SDK);
  runtimeBuilder.setFlavor(RUNTIME_FLAVOR);
  runtimeBuilder.setVersion(RUNTIME_VERSION);

  Tasks.TaskMetadata.Builder metadataBuilder = Tasks.TaskMetadata.newBuilder();
  metadataBuilder.setRuntime(runtimeBuilder.build());
  metadataBuilder.setRetries(serialize(taskTemplate.retries()));
  Tasks.TaskMetadata metadata = metadataBuilder.build();

  Container container =
      requireNonNull(
          taskTemplate.container(), "Only container based task templates are supported");

  Tasks.TaskTemplate.Builder templateBuilder = Tasks.TaskTemplate.newBuilder();
  templateBuilder.setContainer(serialize(container));
  templateBuilder.setMetadata(metadata);
  templateBuilder.setInterface(serialize(taskTemplate.interface_()));
  templateBuilder.setType(taskTemplate.type());
  templateBuilder.setCustom(serializeStruct(taskTemplate.custom()));
  return templateBuilder.build();
}
Use of org.flyte.api.v1.TaskTemplate in the flytekit-java project by flyteorg.
Class ProjectClosureTest, method testCreateTaskTemplateForRunnableTask.
@Test
public void testCreateTaskTemplateForRunnableTask() {
  // given: a runnable task from the test helper and the image the template should use
  String image = "my-image";
  RunnableTask runnableTask = createRunnableTask(null);

  // when
  TaskTemplate template = ProjectClosure.createTaskTemplateForRunnableTask(runnableTask, image);

  // then: a container exists, with the requested image and resources equal to an unconfigured
  // Resources builder
  Container container = template.container();
  assertNotNull(container);
  assertThat(container.image(), equalTo(image));
  assertThat(container.resources(), equalTo(Resources.builder().build()));

  // then: inputs and outputs both use the variable map of SdkTypes.nulls()
  TypedInterface expectedInterface =
      TypedInterface.builder()
          .inputs(SdkTypes.nulls().getVariableMap())
          .outputs(SdkTypes.nulls().getVariableMap())
          .build();
  assertThat(template.interface_(), equalTo(expectedInterface));

  // then: empty custom struct, zero retries and the "java-task" type
  assertThat(template.custom(), equalTo(Struct.of(emptyMap())));
  assertThat(template.retries(), equalTo(RetryStrategy.builder().retries(0).build()));
  assertThat(template.type(), equalTo("java-task"));
}
Use of org.flyte.api.v1.TaskTemplate in the flytekit-java project by flyteorg.
Class ProtoReaderTest, method shouldReadTaskTemplate.
@Test
void shouldReadTaskTemplate() throws IOException {
  // given: a minimal proto template (type, image, command, args only) written via writeProto to
  // the extension's file system
  Tasks.Container containerProto =
      Tasks.Container.newBuilder()
          .setImage("image")
          .addCommand("jflyte")
          .addArgs("arg1")
          .addArgs("arg2")
          .build();
  Tasks.TaskTemplate templateProto =
      Tasks.TaskTemplate.newBuilder().setType("jflyte").setContainer(containerProto).build();
  Path templatePath = extension.getFileSystem().getPath("/test/template");
  writeProto(templateProto, templatePath);

  // when
  ProtoReader reader = new ProtoReader(new InMemoryFileSystem(extension.getFileSystem()));
  TaskTemplate actual = reader.getTaskTemplate("/test/template");

  // then: fields left unset in the proto are read back as empty env, empty interface,
  // zero retries and an empty custom struct
  Container expectedContainer =
      Container.builder()
          .image("image")
          .command(singletonList("jflyte"))
          .args(asList("arg1", "arg2"))
          .env(emptyList())
          .build();
  TaskTemplate expected =
      TaskTemplate.builder()
          .type("jflyte")
          .container(expectedContainer)
          .interface_(TypedInterface.builder().inputs(emptyMap()).outputs(emptyMap()).build())
          .retries(RetryStrategy.builder().retries(0).build())
          .custom(Struct.of(emptyMap()))
          .build();
  assertThat(actual, equalTo(expected));
}
Aggregations