Search in sources:

Example 1 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

In class MLTrainAndPredictTaskRunner, method trainAndPredict.

/**
 * Trains a model and immediately runs prediction with it, replying to {@code listener}.
 *
 * <p>Increments the executing, total, and per-algorithm request counters. It then registers
 * {@code mlTask} with the task manager, marks the task RUNNING, and calls
 * {@code MLEngine.trainAndPredict} with {@code inputDataFrame} wrapped as a
 * {@link DataFrameInputDataset}. On success, a prediction output is stamped with COMPLETED
 * status and returned in an {@link MLTaskResponse}. On any exception, the failure is handed to
 * {@code handlePredictFailure}.
 *
 * <p>Both outcomes go through the listener returned by {@code wrappedCleanupListener}. Whatever
 * per-task cleanup that wrapper performs therefore runs on failure as well as on success.
 * NOTE(review): presumably that cleanup undoes the executing-count increment and cache
 * registration done here; confirm in wrappedCleanupListener.
 *
 * @param mlTask task being executed; its id, function name and async flag are read here
 * @param inputDataFrame data used as the input dataset for training and prediction
 * @param request training request supplying the {@link MLInput} (algorithm and parameters)
 * @param listener caller's listener; receives either the task response or the failure
 */
private void trainAndPredict(MLTask mlTask, DataFrame inputDataFrame, MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
    // track ML task count and add ML task into cache
    mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
    mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
    mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.TRAIN_PREDICT)).increment();
    mlTaskManager.add(mlTask);
    MLInput mlInput = request.getMlInput();
    MLTaskResponse response;
    // run train and predict
    try {
        mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
        MLOutput output = MLEngine.trainAndPredict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build());
        handleAsyncMLTaskComplete(mlTask);
        if (output instanceof MLPredictionOutput) {
            ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
        }
        response = MLTaskResponse.builder().output(output).build();
        log.info("Train and predict task done for algorithm: {}, task id: {}", mlTask.getFunctionName(), mlTask.getTaskId());
    } catch (Exception e) {
        // todo need to specify what exception
        log.error("Failed to train and predict {}", mlInput.getAlgorithm(), e);
        // Report through internalListener, not the raw listener, so the cleanup wrapper also runs on failure.
        handlePredictFailure(mlTask, internalListener, e, true);
        return;
    }
    // Delivered outside the try: if the caller's onResponse throws, the catch above must not
    // also signal onFailure (that would run the cleanup wrapper a second time).
    internalListener.onResponse(response);
}
Also used: MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInput(org.opensearch.ml.common.parameter.MLInput) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLOutput(org.opensearch.ml.common.parameter.MLOutput) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput)

Example 2 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

In class PredictionITTests, method testPredictionWithoutAlgorithm.

/**
 * Checks that a prediction request whose input names no algorithm is rejected with an
 * {@code ActionRequestValidationException}.
 */
public void testPredictionWithoutAlgorithm() throws IOException {
    MLPredictionTaskRequest requestWithoutAlgorithm = new MLPredictionTaskRequest(
        taskId,
        MLInput.builder().inputDataset(DATA_FRAME_INPUT_DATASET).build()
    );
    ActionFuture<MLTaskResponse> pending = client().execute(MLPredictionTaskAction.INSTANCE, requestWithoutAlgorithm);
    expectThrows(ActionRequestValidationException.class, pending::actionGet);
}
Also used: MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInput(org.opensearch.ml.common.parameter.MLInput) MLPredictionTaskRequest(org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest)

Example 3 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

In class PredictionITTests, method testPredictionWithoutModelId.

/**
 * Checks that a KMEANS prediction request carrying an empty task id fails with a
 * {@code ResourceNotFoundException}.
 */
public void testPredictionWithoutModelId() throws IOException {
    MLInput kmeansInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(DATA_FRAME_INPUT_DATASET)
        .build();
    ActionFuture<MLTaskResponse> pending = client().execute(
        MLPredictionTaskAction.INSTANCE,
        new MLPredictionTaskRequest("", kmeansInput)
    );
    expectThrows(ResourceNotFoundException.class, pending::actionGet);
}
Also used: MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInput(org.opensearch.ml.common.parameter.MLInput) MLPredictionTaskRequest(org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest)

Example 4 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

In class PredictionITTests, method testPredictionWithEmptyDataset.

/**
 * Checks that a KMEANS prediction over the dataset returned by {@code generateEmptyDataset()}
 * fails with an {@code IllegalArgumentException}.
 */
public void testPredictionWithEmptyDataset() throws IOException {
    MLInputDataset noRows = generateEmptyDataset();
    MLPredictionTaskRequest emptyDatasetRequest = new MLPredictionTaskRequest(
        taskId,
        MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(noRows).build()
    );
    ActionFuture<MLTaskResponse> pending = client().execute(MLPredictionTaskAction.INSTANCE, emptyDatasetRequest);
    expectThrows(IllegalArgumentException.class, pending::actionGet);
}
Also used: MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInput(org.opensearch.ml.common.parameter.MLInput) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) MLPredictionTaskRequest(org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest)

Example 5 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

In class PredictionITTests, method testPredictionWithoutDataset.

/**
 * Checks that a KMEANS prediction request with no input dataset is rejected with an
 * {@code ActionRequestValidationException}.
 */
public void testPredictionWithoutDataset() throws IOException {
    MLInput datasetlessInput = MLInput.builder().algorithm(FunctionName.KMEANS).build();
    ActionFuture<MLTaskResponse> pending = client().execute(
        MLPredictionTaskAction.INSTANCE,
        new MLPredictionTaskRequest(taskId, datasetlessInput)
    );
    expectThrows(ActionRequestValidationException.class, pending::actionGet);
}
Also used: MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInput(org.opensearch.ml.common.parameter.MLInput) MLPredictionTaskRequest(org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest)

Aggregations

MLInput (org.opensearch.ml.common.parameter.MLInput)46 Test (org.junit.Test)18 MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset)13 MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse)12 DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset)11 DataFrame (org.opensearch.ml.common.dataframe.DataFrame)10 Input (org.opensearch.ml.common.parameter.Input)9 LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput)9 MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput)7 MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest)7 MLTrainingTaskRequest (org.opensearch.ml.common.transport.training.MLTrainingTaskRequest)7 MLOutput (org.opensearch.ml.common.parameter.MLOutput)6 XContentParser (org.opensearch.common.xcontent.XContentParser)5 Response (org.opensearch.client.Response)4 Model (org.opensearch.ml.common.parameter.Model)4 KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame)4 LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame)4 LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame)4 VisibleForTesting (com.google.common.annotations.VisibleForTesting)3 Instant (java.time.Instant)3