use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.
the class MLTrainAndPredictTaskRunner method trainAndPredict.
/**
 * Trains a model and predicts on the same data for the given task, then reports the outcome to the caller.
 *
 * <p>Steps, in order: wrap {@code listener} via {@code wrappedCleanupListener}; increment the executing-task,
 * total-request and per-function {@code TRAIN_PREDICT} request counters; register the task with
 * {@code mlTaskManager}; mark the task {@code RUNNING}; call {@code MLEngine.trainAndPredict} on a copy of the
 * request's {@link MLInput} whose dataset is replaced by {@code inputDataFrame}; finally deliver the response
 * through the wrapped listener. Any exception raised inside the try block is logged and handed to
 * {@code handlePredictFailure}.
 *
 * @param mlTask         task being executed; supplies the task id, function name and async flag
 * @param inputDataFrame data frame used as the input dataset for training and prediction
 * @param request        request carrying the {@link MLInput} to run
 * @param listener       caller's listener, notified with the response or with the failure
 */
private void trainAndPredict(MLTask mlTask, DataFrame inputDataFrame, MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
    // track ML task count and add ML task into cache
    mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
    mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
    mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.TRAIN_PREDICT)).increment();
    mlTaskManager.add(mlTask);
    MLInput mlInput = request.getMlInput();
    // run train and predict
    try {
        mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
        MLOutput output = MLEngine.trainAndPredict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build());
        handleAsyncMLTaskComplete(mlTask);
        if (output instanceof MLPredictionOutput) {
            ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
        }
        MLTaskResponse response = MLTaskResponse.builder().output(output).build();
        log.info("Train and predict task done for algorithm: {}, task id: {}", mlTask.getFunctionName(), mlTask.getTaskId());
        internalListener.onResponse(response);
    } catch (Exception e) {
        // todo need to specify what exception
        // A null mlInput is one way to land here (NPE on mlInput.toBuilder() above). Dereferencing it
        // unguarded would throw again from inside this handler, skip handlePredictFailure and leave the
        // caller's listener without any notification.
        Object algorithm = mlInput == null ? null : mlInput.getAlgorithm();
        log.error("Failed to train and predict {}", algorithm, e);
        // NOTE(review): the failure is reported on the raw `listener`, not on `internalListener`, so whatever
        // wrappedCleanupListener adds does not run on this path unless handlePredictFailure does the same
        // work itself -- confirm against those helpers before changing.
        handlePredictFailure(mlTask, listener, e, true);
    }
}
use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.
the class PredictionITTests method testPredictionWithoutAlgorithm.
/**
 * Submits a prediction request whose {@link MLInput} is built with a dataset but no algorithm and
 * expects resolving the returned future to throw {@link ActionRequestValidationException}.
 */
public void testPredictionWithoutAlgorithm() throws IOException {
    // Only the dataset is set on the builder; algorithm(...) is deliberately never called.
    MLInput inputMissingAlgorithm = MLInput.builder()
        .inputDataset(DATA_FRAME_INPUT_DATASET)
        .build();
    MLPredictionTaskRequest request = new MLPredictionTaskRequest(taskId, inputMissingAlgorithm);
    ActionFuture<MLTaskResponse> pendingResponse = client().execute(MLPredictionTaskAction.INSTANCE, request);
    expectThrows(ActionRequestValidationException.class, () -> pendingResponse.actionGet());
}
use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.
the class PredictionITTests method testPredictionWithoutModelId.
/**
 * Submits a KMEANS prediction request with an empty string as the request's first constructor argument
 * (the model id, going by the test name) and expects resolving the returned future to throw
 * {@link ResourceNotFoundException}.
 */
public void testPredictionWithoutModelId() throws IOException {
    MLInput kmeansInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(DATA_FRAME_INPUT_DATASET)
        .build();
    // The empty string takes the place of the taskId value the sibling tests pass here.
    MLPredictionTaskRequest request = new MLPredictionTaskRequest("", kmeansInput);
    ActionFuture<MLTaskResponse> pendingResponse = client().execute(MLPredictionTaskAction.INSTANCE, request);
    expectThrows(ResourceNotFoundException.class, () -> pendingResponse.actionGet());
}
use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.
the class PredictionITTests method testPredictionWithEmptyDataset.
/**
 * Submits a KMEANS prediction request whose input dataset comes from {@code generateEmptyDataset()} and
 * expects resolving the returned future to throw {@link IllegalArgumentException}.
 */
public void testPredictionWithEmptyDataset() throws IOException {
    MLInputDataset emptyDataset = generateEmptyDataset();
    MLInput inputWithEmptyDataset = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(emptyDataset)
        .build();
    MLPredictionTaskRequest request = new MLPredictionTaskRequest(taskId, inputWithEmptyDataset);
    ActionFuture<MLTaskResponse> pendingResponse = client().execute(MLPredictionTaskAction.INSTANCE, request);
    expectThrows(IllegalArgumentException.class, () -> pendingResponse.actionGet());
}
use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.
the class PredictionITTests method testPredictionWithoutDataset.
/**
 * Submits a KMEANS prediction request whose {@link MLInput} is built without any input dataset and
 * expects resolving the returned future to throw {@link ActionRequestValidationException}.
 */
public void testPredictionWithoutDataset() throws IOException {
    // Only the algorithm is set on the builder; inputDataset(...) is deliberately never called.
    MLInput inputMissingDataset = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .build();
    MLPredictionTaskRequest request = new MLPredictionTaskRequest(taskId, inputMissingDataset);
    ActionFuture<MLTaskResponse> pendingResponse = client().execute(MLPredictionTaskAction.INSTANCE, request);
    expectThrows(ActionRequestValidationException.class, () -> pendingResponse.actionGet());
}
Aggregations