Usage of `org.opensearch.ml.common.parameter.MLTask` in the ml-commons project (opensearch-project).
From class `MLTrainAndPredictTaskRunner`, method `startTrainAndPredictionTask`.
/**
 * Starts a train-and-predict task for the given training request and reports the outcome
 * to {@code listener}.
 *
 * <p>An {@link MLTask} of type {@code TRAINING_AND_PREDICTION} is created in state
 * {@code CREATED}, bound to the local node and flagged as non-async. If the input dataset is a
 * search query, it is resolved to a {@link DataFrame} asynchronously and the work continues on
 * {@code TASK_THREAD_POOL}. Otherwise, the data frame is parsed on the calling thread and the
 * work is handed off to {@code TASK_THREAD_POOL}.
 *
 * <p>Failures while preparing the input or scheduling the work are routed through
 * {@code handlePredictFailure} instead of being thrown back to the caller.
 *
 * @param request training request carrying the {@link MLInput} (function name and dataset)
 * @param listener action listener notified with the task response or the failure
 */
public void startTrainAndPredictionTask(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInput mlInput = request.getMlInput();
    MLInputDataType inputDataType = mlInput.getInputDataset().getInputDataType();
    Instant now = Instant.now();
    MLTask mlTask = MLTask.builder()
        .taskId(UUID.randomUUID().toString())
        .taskType(MLTaskType.TRAINING_AND_PREDICTION)
        .inputType(inputDataType)
        .functionName(mlInput.getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(now)
        .lastUpdateTime(now)
        .async(false)
        .build();

    if (inputDataType == MLInputDataType.SEARCH_QUERY) {
        ActionListener<DataFrame> dataFrameActionListener = ActionListener.wrap(
            dataFrame -> trainAndPredict(mlTask, dataFrame, request, listener),
            e -> {
                log.error("Failed to generate DataFrame from search query", e);
                handlePredictFailure(mlTask, listener, e, false);
            }
        );
        // Resolve the query asynchronously; the continuation runs on the task thread pool.
        mlInputDatasetHandler.parseSearchQueryInput(
            mlInput.getInputDataset(),
            new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false)
        );
    } else {
        // Parsing happens on the caller's thread and submission may be rejected by the executor;
        // report either failure through the listener, as the search-query branch does.
        try {
            DataFrame inputDataFrame = mlInputDatasetHandler.parseDataFrameInput(mlInput.getInputDataset());
            threadPool.executor(TASK_THREAD_POOL).execute(() -> trainAndPredict(mlTask, inputDataFrame, request, listener));
        } catch (Exception e) {
            log.error("Failed to start train and predict task", e);
            handlePredictFailure(mlTask, listener, e, false);
        }
    }
}
Aggregations