Search in sources:

Example 21 with DataFrame

use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.

In class MLPredictTaskRunner, method startPredictionTask:

/**
 * Builds a prediction {@link MLTask} for the request and kicks off the prediction.
 *
 * <p>The task is created in {@code CREATED} state, bound to the local node and flagged as
 * non-async. If the input dataset is a search query, it is first resolved into a
 * {@link DataFrame} (the callback is dispatched through a {@link ThreadedActionListener} on
 * {@code TASK_THREAD_POOL}); otherwise the data frame is parsed directly from the request and
 * the prediction is submitted to the {@code TASK_THREAD_POOL} executor.
 *
 * @param request MLPredictionTaskRequest holding the model id and the ML input
 * @param listener Action listener passed on to {@code predict}; it is also failed directly
 *                 when the data frame cannot be generated from the search query
 */
public void startPredictionTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
    Instant createdAt = Instant.now();
    MLTask mlTask = MLTask.builder()
        .taskId(UUID.randomUUID().toString())
        .modelId(request.getModelId())
        .taskType(MLTaskType.PREDICTION)
        .inputType(inputDataType)
        .functionName(request.getMlInput().getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(createdAt)
        .lastUpdateTime(createdAt)
        .async(false)
        .build();
    MLInput mlInput = request.getMlInput();

    // Data-frame input: parse it right here, then run the prediction on the task thread pool.
    if (!mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
        DataFrame inputDataFrame = mlInputDatasetHandler.parseDataFrameInput(mlInput.getInputDataset());
        threadPool.executor(TASK_THREAD_POOL).execute(() -> predict(mlTask, inputDataFrame, request, listener));
        return;
    }

    // Search-query input: the data frame arrives asynchronously, so predict from the callback.
    ActionListener<DataFrame> dataFrameActionListener = ActionListener.wrap(
        dataFrame -> predict(mlTask, dataFrame, request, listener),
        failure -> {
            log.error("Failed to generate DataFrame from search query", failure);
            handleAsyncMLTaskFailure(mlTask, failure);
            listener.onFailure(failure);
        }
    );
    mlInputDatasetHandler.parseSearchQueryInput(
        mlInput.getInputDataset(),
        new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false)
    );
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Instant(java.time.Instant) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) MLTask(org.opensearch.ml.common.parameter.MLTask)

Example 22 with DataFrame

use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.

In class MLEngineTest, method predictLinearRegression:

/**
 * Trains a linear regression model, predicts on the prediction data frame and
 * checks that the prediction result has size 2.
 */
@Test
public void predictLinearRegression() {
    Model trainedModel = trainLinearRegressionModel();
    DataFrame features = constructLinearRegressionPredictionDataFrame();
    MLInputDataset dataset = DataFrameInputDataset.builder()
        .dataFrame(features)
        .build();
    Input predictInput = MLInput.builder()
        .algorithm(FunctionName.LINEAR_REGRESSION)
        .inputDataset(dataset)
        .build();

    MLPredictionOutput predictionOutput = (MLPredictionOutput) MLEngine.predict(predictInput, trainedModel);

    Assert.assertEquals(2, predictionOutput.getPredictionResult().size());
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) Model(org.opensearch.ml.common.parameter.Model) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) LinearRegressionHelper.constructLinearRegressionPredictionDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) LinearRegressionHelper.constructLinearRegressionTrainDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) Test(org.junit.Test)

Example 23 with DataFrame

use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.

In class MLEngineTest, method trainAndPredictWithKmeans:

/**
 * Runs train-and-predict with KMeans (default-built parameters) on a data frame
 * built for 100 rows and checks that the prediction result has that same size.
 */
@Test
public void trainAndPredictWithKmeans() {
    final int expectedSize = 100;
    MLAlgoParams kmeansParams = KMeansParams.builder().build();
    DataFrame samples = constructKMeansDataFrame(expectedSize);
    MLInputDataset dataset = new DataFrameInputDataset(samples);
    Input trainAndPredictInput = new MLInput(FunctionName.KMEANS, kmeansParams, dataset);

    MLPredictionOutput predictionOutput = (MLPredictionOutput) MLEngine.trainAndPredict(trainAndPredictInput);
    DataFrame predictions = predictionOutput.getPredictionResult();

    Assert.assertEquals(expectedSize, predictions.size());
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) MLInput(org.opensearch.ml.common.parameter.MLInput) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) LinearRegressionHelper.constructLinearRegressionPredictionDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) LinearRegressionHelper.constructLinearRegressionTrainDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) MLAlgoParams(org.opensearch.ml.common.parameter.MLAlgoParams) Test(org.junit.Test)

Example 24 with DataFrame

use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.

In class MLEngineTest, method trainLinearRegressionModel:

/**
 * Trains a linear regression model on the training data frame.
 *
 * <p>Uses squared loss with the ADAM optimizer (learning rate 0.01, epsilon 1e-6,
 * beta1 0.9, beta2 0.99) and {@code "price"} as the target.
 *
 * @return the model returned by {@code MLEngine.train}
 */
private Model trainLinearRegressionModel() {
    LinearRegressionParams regressionParams = LinearRegressionParams.builder()
        .objectiveType(LinearRegressionParams.ObjectiveType.SQUARED_LOSS)
        .optimizerType(LinearRegressionParams.OptimizerType.ADAM)
        .learningRate(0.01)
        .epsilon(1e-6)
        .beta1(0.9)
        .beta2(0.99)
        .target("price")
        .build();
    DataFrame trainingData = constructLinearRegressionTrainDataFrame();
    MLInputDataset dataset = DataFrameInputDataset.builder()
        .dataFrame(trainingData)
        .build();
    Input trainInput = MLInput.builder()
        .algorithm(FunctionName.LINEAR_REGRESSION)
        .parameters(regressionParams)
        .inputDataset(dataset)
        .build();
    return MLEngine.train(trainInput);
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) LinearRegressionParams(org.opensearch.ml.common.parameter.LinearRegressionParams) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) LinearRegressionHelper.constructLinearRegressionPredictionDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) LinearRegressionHelper.constructLinearRegressionTrainDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame)

Example 25 with DataFrame

use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.

In class MLEngineTest, method predictKMeans:

/**
 * Predicts with a trained KMeans model on a data frame built for 10 rows and checks
 * that 10 predictions come back, each with a first-column value of 0 or 1.
 */
@Test
public void predictKMeans() {
    Model trainedModel = trainKMeansModel();
    DataFrame samples = constructKMeansDataFrame(10);
    MLInputDataset dataset = DataFrameInputDataset.builder()
        .dataFrame(samples)
        .build();
    Input predictInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(dataset)
        .build();

    MLPredictionOutput predictionOutput = (MLPredictionOutput) MLEngine.predict(predictInput, trainedModel);
    DataFrame predictions = predictionOutput.getPredictionResult();

    Assert.assertEquals(10, predictions.size());
    predictions.forEach(row -> {
        int clusterId = row.getValue(0).intValue();
        Assert.assertTrue(clusterId == 0 || clusterId == 1);
    });
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) Model(org.opensearch.ml.common.parameter.Model) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) LinearRegressionHelper.constructLinearRegressionPredictionDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) LinearRegressionHelper.constructLinearRegressionTrainDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) Test(org.junit.Test)

Aggregations

DataFrame (org.opensearch.ml.common.dataframe.DataFrame)34 ColumnMeta (org.opensearch.ml.common.dataframe.ColumnMeta)10 DefaultDataFrame (org.opensearch.ml.common.dataframe.DefaultDataFrame)10 MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput)10 MLInput (org.opensearch.ml.common.parameter.MLInput)9 ArrayList (java.util.ArrayList)8 Test (org.junit.Test)8 Model (org.opensearch.ml.common.parameter.Model)8 Row (org.opensearch.ml.common.dataframe.Row)7 DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset)7 MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset)7 KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame)7 HashMap (java.util.HashMap)6 ColumnValue (org.opensearch.ml.common.dataframe.ColumnValue)6 LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame)5 LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame)5 List (java.util.List)4 Before (org.junit.Before)4 Input (org.opensearch.ml.common.parameter.Input)4 LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput)4