Usage of org.flyte.api.v1.TaskIdentifier in the flytekit-java project by flyteorg.
Example: method shouldLoadRunnableTasksFromDiscoveredRegistries of class SdkRunnableTaskRegistrarTest.
/**
 * Verifies that {@code registrar.load(ENV)} exposes both {@code TestTask} and {@code
 * OtherTestTask}, each keyed by a {@link TaskIdentifier} with project {@code "project"}, domain
 * {@code "domain"} and version {@code "version"}, and that each loaded task matches the
 * corresponding expected {@link RunnableTask}.
 */
@Test
void shouldLoadRunnableTasksFromDiscoveredRegistries() {
  // given
  String firstName = "org.flyte.flytekit.SdkRunnableTaskRegistrarTest$TestTask";
  String secondName = "org.flyte.flytekit.SdkRunnableTaskRegistrarTest$OtherTestTask";

  // Both expected tasks share the same interface built from SdkTypes.nulls().
  TypedInterface sharedInterface =
      TypedInterface.builder()
          .inputs(SdkTypes.nulls().getVariableMap())
          .outputs(SdkTypes.nulls().getVariableMap())
          .build();

  // TestTask: zero retries, null resources.
  TaskIdentifier firstId =
      TaskIdentifier.builder()
          .project("project")
          .domain("domain")
          .name(firstName)
          .version("version")
          .build();
  RunnableTask firstExpected =
      createRunnableTask(
          firstName, sharedInterface, RetryStrategy.builder().retries(0).build(), null);

  // OtherTestTask: one retry plus explicit CPU/memory limits and requests.
  // NOTE(review): requests exceed limits here (CPU 2 vs 0.5, memory 5Gi vs 2Gi); presumably this
  // mirrors what OtherTestTask declares -- confirm it is intentional test data.
  Map<Resources.ResourceName, String> cpuMemLimits = new HashMap<>();
  cpuMemLimits.put(Resources.ResourceName.CPU, "0.5");
  cpuMemLimits.put(Resources.ResourceName.MEMORY, "2Gi");
  Map<Resources.ResourceName, String> cpuMemRequests = new HashMap<>();
  cpuMemRequests.put(Resources.ResourceName.CPU, "2");
  cpuMemRequests.put(Resources.ResourceName.MEMORY, "5Gi");
  TaskIdentifier secondId =
      TaskIdentifier.builder()
          .project("project")
          .domain("domain")
          .name(secondName)
          .version("version")
          .build();
  RunnableTask secondExpected =
      createRunnableTask(
          secondName,
          sharedInterface,
          RetryStrategy.builder().retries(1).build(),
          Resources.builder().limits(cpuMemLimits).requests(cpuMemRequests).build());

  // when
  Map<TaskIdentifier, RunnableTask> loaded = registrar.load(ENV);

  // then: report both missing keys together before comparing task contents
  assertAll(
      () -> assertThat(loaded, hasKey(is(firstId))),
      () -> assertThat(loaded, hasKey(is(secondId))));
  assertTaskEquals(loaded.get(firstId), firstExpected);
  assertTaskEquals(loaded.get(secondId), secondExpected);
}
Aggregations