Example 11 with MLPredictionOutput

Use of org.opensearch.ml.common.parameter.MLPredictionOutput in project ml-commons by opensearch-project.

From the class MLEngineTest, method predictLinearRegression.

@Test
public void predictLinearRegression() {
    Model model = trainLinearRegressionModel();
    DataFrame predictionDataFrame = constructLinearRegressionPredictionDataFrame();
    MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(predictionDataFrame).build();
    Input mlInput = MLInput.builder().algorithm(FunctionName.LINEAR_REGRESSION).inputDataset(inputDataset).build();
    MLPredictionOutput output = (MLPredictionOutput) MLEngine.predict(mlInput, model);
    DataFrame predictions = output.getPredictionResult();
    Assert.assertEquals(2, predictions.size());
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) Model(org.opensearch.ml.common.parameter.Model) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) LinearRegressionHelper.constructLinearRegressionPredictionDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) LinearRegressionHelper.constructLinearRegressionTrainDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) Test(org.junit.Test)
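The test only checks that predicting over the frame from constructLinearRegressionPredictionDataFrame() yields two predictions, which amounts to one output per input row. The following is a minimal, library-free sketch of that train-then-predict contract. It assumes a single numeric feature and ordinary least squares. SimpleLinearModel is a hypothetical name, and the sketch does not mirror how ml-commons implements linear regression.

```java
/**
 * Hypothetical single-feature linear model (y = slope * x + intercept),
 * fitted with ordinary least squares. Not the ml-commons implementation.
 */
final class SimpleLinearModel {
    final double slope;
    final double intercept;

    private SimpleLinearModel(double slope, double intercept) {
        this.slope = slope;
        this.intercept = intercept;
    }

    /** "Train": closed-form least-squares fit over paired samples. */
    static SimpleLinearModel train(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two paired samples");
        }
        double meanX = 0.0;
        double meanY = 0.0;
        for (int i = 0; i < x.length; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= x.length;
        meanY /= y.length;
        double covXY = 0.0;
        double varX = 0.0;
        for (int i = 0; i < x.length; i++) {
            covXY += (x[i] - meanX) * (y[i] - meanY);
            varX += (x[i] - meanX) * (x[i] - meanX);
        }
        if (varX == 0.0) {
            throw new IllegalArgumentException("feature has zero variance");
        }
        double slope = covXY / varX;
        return new SimpleLinearModel(slope, meanY - slope * meanX);
    }

    /** "Predict": exactly one output per input row, which is what the size assertion checks. */
    double[] predict(double[] x) {
        double[] predictions = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            predictions[i] = slope * x[i] + intercept;
        }
        return predictions;
    }
}
```

For example, training on x = {1, 2, 3, 4} and y = {3, 5, 7, 9} gives slope 2 and intercept 1. With that model, predict(new double[] {5, 6}) returns two values, 11.0 and 13.0.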

Example 12 with MLPredictionOutput

Use of org.opensearch.ml.common.parameter.MLPredictionOutput in project ml-commons by opensearch-project.

From the class MLEngineTest, method trainAndPredictWithKmeans.

@Test
public void trainAndPredictWithKmeans() {
    int dataSize = 100;
    MLAlgoParams parameters = KMeansParams.builder().build();
    DataFrame dataFrame = constructKMeansDataFrame(dataSize);
    MLInputDataset inputData = new DataFrameInputDataset(dataFrame);
    Input input = new MLInput(FunctionName.KMEANS, parameters, inputData);
    MLPredictionOutput output = (MLPredictionOutput) MLEngine.trainAndPredict(input);
    Assert.assertEquals(dataSize, output.getPredictionResult().size());
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) LinearRegressionHelper.constructLinearRegressionPredictionDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) LinearRegressionHelper.constructLinearRegressionTrainDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) MLAlgoParams(org.opensearch.ml.common.parameter.MLAlgoParams) Test(org.junit.Test)
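MLEngine.trainAndPredict fits a model and labels the same input in a single call, so the test expects exactly dataSize predictions back. The following is a hedged, self-contained sketch of that idea using plain Lloyd's k-means. TinyKMeans, its first-k-points initialization and the fixed iteration count are all assumptions made for illustration, not details taken from ml-commons.

```java
/**
 * Hypothetical train-and-predict k-means (Lloyd's algorithm) over dense points.
 * Initial centroids are simply the first k points; real libraries use smarter seeding.
 */
final class TinyKMeans {
    private TinyKMeans() {}

    /** Fits k centroids on points and returns one cluster id in [0, k) per point. */
    static int[] trainAndPredict(double[][] points, int k, int iterations) {
        if (k <= 0 || points.length < k) {
            throw new IllegalArgumentException("need k > 0 and at least k points");
        }
        int dim = points[0].length;
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = points[c].clone();
        }
        int[] labels = new int[points.length];
        for (int iter = 0; iter < iterations; iter++) {
            // Assignment step: attach every point to its nearest centroid.
            assign(points, centroids, labels);
            // Update step: move each centroid to the mean of the points assigned to it.
            double[][] sums = new double[k][dim];
            int[] counts = new int[k];
            for (int i = 0; i < points.length; i++) {
                counts[labels[i]]++;
                for (int d = 0; d < dim; d++) {
                    sums[labels[i]][d] += points[i][d];
                }
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] == 0) {
                    continue; // leave an empty cluster's centroid where it was
                }
                for (int d = 0; d < dim; d++) {
                    centroids[c][d] = sums[c][d] / counts[c];
                }
            }
        }
        // Final "predict" pass against the trained centroids.
        assign(points, centroids, labels);
        return labels;
    }

    /** Writes the index of the closest centroid (squared Euclidean distance) into labels. */
    private static void assign(double[][] points, double[][] centroids, int[] labels) {
        for (int i = 0; i < points.length; i++) {
            int best = 0;
            double bestDist = Double.POSITIVE_INFINITY;
            for (int c = 0; c < centroids.length; c++) {
                double dist = 0.0;
                for (int d = 0; d < points[i].length; d++) {
                    double diff = points[i][d] - centroids[c][d];
                    dist += diff * diff;
                }
                if (dist < bestDist) {
                    bestDist = dist;
                    best = c;
                }
            }
            labels[i] = best;
        }
    }
}
```

The output length always equals the input length, which is the property the test's assertEquals(dataSize, ...) relies on. Which cluster gets id 0 versus 1 depends on the initialization.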

Example 13 with MLPredictionOutput

Use of org.opensearch.ml.common.parameter.MLPredictionOutput in project ml-commons by opensearch-project.

From the class MLEngineTest, method predictKMeans.

@Test
public void predictKMeans() {
    Model model = trainKMeansModel();
    DataFrame predictionDataFrame = constructKMeansDataFrame(10);
    MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(predictionDataFrame).build();
    Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build();
    MLPredictionOutput output = (MLPredictionOutput) MLEngine.predict(mlInput, model);
    DataFrame predictions = output.getPredictionResult();
    Assert.assertEquals(10, predictions.size());
    predictions.forEach(row -> {
        // Each predicted cluster id must be 0 or 1.
        int cluster = row.getValue(0).intValue();
        Assert.assertTrue(cluster == 0 || cluster == 1);
    });
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) Model(org.opensearch.ml.common.parameter.Model) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) LinearRegressionHelper.constructLinearRegressionPredictionDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) LinearRegressionHelper.constructLinearRegressionTrainDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) Test(org.junit.Test)
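Unlike trainAndPredictWithKmeans, this test reuses a model trained beforehand (trainKMeansModel()) and only asks for cluster ids on 10 new rows. The forEach assertion accepts only ids 0 and 1, which suggests the model was trained with two clusters. Below is a hedged sketch of prediction against fixed centroids, with the range check generalized to [0, k). NearestCentroid and its methods are hypothetical names.

```java
import java.util.Arrays;

/** Hypothetical predict-only k-means: assigns points to already-trained centroids. */
final class NearestCentroid {
    private NearestCentroid() {}

    /** Returns, for each point, the index of the closest centroid (squared Euclidean distance). */
    static int[] predict(double[][] centroids, double[][] points) {
        int[] labels = new int[points.length];
        for (int i = 0; i < points.length; i++) {
            double bestDist = Double.POSITIVE_INFINITY;
            for (int c = 0; c < centroids.length; c++) {
                double dist = 0.0;
                for (int d = 0; d < points[i].length; d++) {
                    double diff = points[i][d] - centroids[c][d];
                    dist += diff * diff;
                }
                if (dist < bestDist) {
                    bestDist = dist;
                    labels[i] = c;
                }
            }
        }
        return labels;
    }

    /** Generalizes the test's "0 or 1" check to any number of clusters k. */
    static boolean allLabelsInRange(int[] labels, int k) {
        return Arrays.stream(labels).allMatch(label -> label >= 0 && label < k);
    }
}
```

Because prediction never moves the centroids, every label is necessarily a valid centroid index. The range check therefore guards against a mismatch between the model and the expected k, not against randomness.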

Example 14 with MLPredictionOutput

Use of org.opensearch.ml.common.parameter.MLPredictionOutput in project ml-commons by opensearch-project.

From the class AnomalyDetectionLibSVMTest, method predict.

@Test
public void predict() {
    Model model = anomalyDetection.train(trainDataFrame);
    MLPredictionOutput output = (MLPredictionOutput) anomalyDetection.predict(predictionDataFrame, model);
    DataFrame predictions = output.getPredictionResult();
    int i = 0;
    int truePositive = 0;
    int falsePositive = 0;
    int totalPositive = 0;
    // Compare each predicted type (column 1) with the ground-truth label at the same index.
    for (Row row : predictions) {
        String type = row.getValue(1).stringValue();
        if (predictionLabels.get(i) == Event.EventType.ANOMALOUS) {
            totalPositive++;
            if ("ANOMALOUS".equals(type)) {
                truePositive++;
            }
        } else if ("ANOMALOUS".equals(type)) {
            falsePositive++;
        }
        i++;
    }
    // precision = TP / (TP + FP); recall = TP / (all actual anomalies)
    float precision = (float) truePositive / (truePositive + falsePositive);
    float recall = (float) truePositive / totalPositive;
    Assert.assertEquals(0.7, precision, 0.01);
    Assert.assertEquals(1.0, recall, 0.01);
}
Also used : Model(org.opensearch.ml.common.parameter.Model) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) DefaultDataFrame(org.opensearch.ml.common.dataframe.DefaultDataFrame) Row(org.opensearch.ml.common.dataframe.Row) Test(org.junit.Test)
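In this test, totalPositive counts the ground-truth anomalies, so it equals TP + FN. The assertions require precision within 0.01 of 0.7 and recall within 0.01 of 1.0. If the model flags nothing, truePositive + falsePositive is 0 and the float division yields NaN, which fails the first assertion. The sketch below computes the same metrics over boolean label arrays. DetectionMetrics is a hypothetical helper. It returns 0 when a denominator is 0, which is one common convention assumed here rather than taken from ml-commons.

```java
/** Hypothetical precision/recall helper for a binary anomalous-vs-normal labeling. */
final class DetectionMetrics {
    final int truePositive;
    final int falsePositive;
    final int falseNegative;

    private DetectionMetrics(int truePositive, int falsePositive, int falseNegative) {
        this.truePositive = truePositive;
        this.falsePositive = falsePositive;
        this.falseNegative = falseNegative;
    }

    /** actual[i] / predicted[i] are true when row i is (or is predicted to be) anomalous. */
    static DetectionMetrics of(boolean[] actual, boolean[] predicted) {
        if (actual.length != predicted.length) {
            throw new IllegalArgumentException("label arrays must have the same length");
        }
        int tp = 0;
        int fp = 0;
        int fn = 0;
        for (int i = 0; i < actual.length; i++) {
            if (predicted[i] && actual[i]) {
                tp++;
            } else if (predicted[i]) {
                fp++;
            } else if (actual[i]) {
                fn++;
            }
        }
        return new DetectionMetrics(tp, fp, fn);
    }

    /** TP / (TP + FP); assumed to be 0 when nothing was flagged, instead of NaN. */
    double precision() {
        int flagged = truePositive + falsePositive;
        return flagged == 0 ? 0.0 : (double) truePositive / flagged;
    }

    /** TP / (TP + FN), i.e. TP over all actual anomalies (the test's totalPositive). */
    double recall() {
        int actualPositives = truePositive + falseNegative;
        return actualPositives == 0 ? 0.0 : (double) truePositive / actualPositives;
    }
}
```

For instance, with five rows where both actual anomalies are flagged and one normal row is also flagged, precision is 2/3 and recall is 1.0.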

Example 15 with MLPredictionOutput

Use of org.opensearch.ml.common.parameter.MLPredictionOutput in project ml-commons by opensearch-project.

From the class FixedInTimeRandomCutForestTest, method predict.

@Test
public void predict() {
    Model model = forest.train(trainDataFrame);
    MLPredictionOutput output = (MLPredictionOutput) forest.predict(predictionDataFrame, model);
    DataFrame predictions = output.getPredictionResult();
    Assert.assertEquals(dataSize, predictions.size());
    int anomalyCount = 0;
    for (int i = 0; i < dataSize; i++) {
        if (i % 100 == 0) {
            if (predictions.getRow(i).getValue(1).doubleValue() > 0.01) {
                anomalyCount++;
            }
        }
    }
    // There are 5 anomalies in total; the test only requires that more than one is detected.
    Assert.assertTrue("Too few anomalies detected: " + anomalyCount, anomalyCount > 1);
}
Also used : Model(org.opensearch.ml.common.parameter.Model) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) DefaultDataFrame(org.opensearch.ml.common.dataframe.DefaultDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) Test(org.junit.Test)
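The loop in this test visits every row but inspects only indices that are multiples of 100. It counts those whose column-1 value exceeds 0.01 and requires more than one such hit. The same sampling can be written by stepping the index directly. The sketch below works over a plain double[] that stands in for column 1 of the prediction frame. SampledThresholdCounter is a hypothetical name.

```java
/** Hypothetical helper mirroring the test's sampled threshold count. */
final class SampledThresholdCounter {
    private SampledThresholdCounter() {}

    /** Counts values[i] > threshold for i = 0, stride, 2 * stride, ... */
    static int countAbove(double[] values, int stride, double threshold) {
        if (stride <= 0) {
            throw new IllegalArgumentException("stride must be positive");
        }
        int count = 0;
        // Stepping by the stride visits the same indices as testing i % stride == 0 on every row.
        for (int i = 0; i < values.length; i += stride) {
            if (values[i] > threshold) {
                count++;
            }
        }
        return count;
    }
}
```

Stepping by the stride avoids a modulo test on every row and makes the sampling explicit in the loop header. Rows between sampled indices are never examined, exactly as in the original loop.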

Aggregations

MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput) 19
Test (org.junit.Test) 16
DataFrame (org.opensearch.ml.common.dataframe.DataFrame) 8
Model (org.opensearch.ml.common.parameter.Model) 8
MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse) 8
MLInput (org.opensearch.ml.common.parameter.MLInput) 7
KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) 5
HashMap (java.util.HashMap) 4
LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) 4
LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) 4
DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset) 3
MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset) 3
Input (org.opensearch.ml.common.parameter.Input) 3
LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) 3
MLOutput (org.opensearch.ml.common.parameter.MLOutput) 3
BytesStreamOutput (org.opensearch.common.io.stream.BytesStreamOutput) 2
XContentBuilder (org.opensearch.common.xcontent.XContentBuilder) 2
DefaultDataFrame (org.opensearch.ml.common.dataframe.DefaultDataFrame) 2
MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest) 2
OpenSearchException (org.opensearch.OpenSearchException) 1