Usage of `org.opensearch.ml.common.parameter.MLPredictionOutput` in the ml-commons project (opensearch-project).
Class `MLEngineTest`, method `predictLinearRegression`.
/**
 * Trains a linear-regression model, then runs {@code MLEngine.predict} on the linear-regression
 * prediction fixture and checks that the returned prediction frame has two rows.
 * NOTE(review): the expected count of 2 presumably mirrors the fixture's row count; confirm
 * against constructLinearRegressionPredictionDataFrame().
 */
@Test
public void predictLinearRegression() {
    Model trainedModel = trainLinearRegressionModel();
    DataFrame fixture = constructLinearRegressionPredictionDataFrame();

    MLInputDataset dataset = DataFrameInputDataset.builder()
        .dataFrame(fixture)
        .build();
    Input predictionInput = MLInput.builder()
        .algorithm(FunctionName.LINEAR_REGRESSION)
        .inputDataset(dataset)
        .build();

    MLPredictionOutput result = (MLPredictionOutput) MLEngine.predict(predictionInput, trainedModel);
    Assert.assertEquals(2, result.getPredictionResult().size());
}
Usage of `org.opensearch.ml.common.parameter.MLPredictionOutput` in the ml-commons project (opensearch-project).
Class `MLEngineTest`, method `trainAndPredictWithKmeans`.
/**
 * Runs the combined train-and-predict path for KMeans with default parameters over a
 * 100-row data frame and checks that one prediction is produced per input row.
 */
@Test
public void trainAndPredictWithKmeans() {
    final int dataSize = 100;
    MLAlgoParams kMeansParams = KMeansParams.builder().build();
    MLInputDataset dataset = new DataFrameInputDataset(constructKMeansDataFrame(dataSize));
    Input trainInput = new MLInput(FunctionName.KMEANS, kMeansParams, dataset);

    DataFrame predicted = ((MLPredictionOutput) MLEngine.trainAndPredict(trainInput)).getPredictionResult();
    Assert.assertEquals(dataSize, predicted.size());
}
Usage of `org.opensearch.ml.common.parameter.MLPredictionOutput` in the ml-commons project (opensearch-project).
Class `MLEngineTest`, method `predictKMeans`.
/**
 * Trains a KMeans model, predicts on a 10-row data frame, and checks that there is one
 * prediction per row and that each row's first column is 0 or 1.
 * NOTE(review): column 0 is presumably the assigned cluster id, and the 0/1 bound presumably
 * reflects the centroid count used by trainKMeansModel(); confirm.
 */
@Test
public void predictKMeans() {
    Model trainedModel = trainKMeansModel();
    int sampleCount = 10;
    DataFrame fixture = constructKMeansDataFrame(sampleCount);

    MLInputDataset dataset = DataFrameInputDataset.builder().dataFrame(fixture).build();
    Input predictionInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(dataset)
        .build();

    DataFrame assignments = ((MLPredictionOutput) MLEngine.predict(predictionInput, trainedModel)).getPredictionResult();
    Assert.assertEquals(sampleCount, assignments.size());
    assignments.forEach(row -> {
        int clusterId = row.getValue(0).intValue();
        Assert.assertTrue(clusterId == 0 || clusterId == 1);
    });
}
Usage of `org.opensearch.ml.common.parameter.MLPredictionOutput` in the ml-commons project (opensearch-project).
Class `AnomalyDetectionLibSVMTest`, method `predict`.
/**
 * Trains the LibSVM anomaly detector and scores its predictions against ground-truth labels.
 *
 * <p>Each prediction row is paired with the label at the same index in {@code predictionLabels}.
 * A row counts as "predicted anomalous" when its column 1 holds the string {@code "ANOMALOUS"}.
 * The test then asserts precision of about 0.7 and recall of about 1.0, each within 0.01.
 *
 * <p>Preconditions are asserted explicitly. Otherwise a length mismatch would surface as an
 * IndexOutOfBoundsException or a partial evaluation, and a zero denominator would surface as a
 * confusing NaN comparison failure.
 */
@Test
public void predict() {
    Model model = anomalyDetection.train(trainDataFrame);
    MLPredictionOutput output = (MLPredictionOutput) anomalyDetection.predict(predictionDataFrame, model);
    DataFrame predictions = output.getPredictionResult();

    // Rows and labels are matched by position, so both sequences must have the same length.
    Assert.assertEquals("prediction count must match label count", predictionLabels.size(), predictions.size());

    int i = 0;
    int truePositive = 0;
    int falsePositive = 0;
    int totalPositive = 0;
    for (Row row : predictions) {
        boolean predictedAnomalous = "ANOMALOUS".equals(row.getValue(1).stringValue());
        boolean labeledAnomalous = predictionLabels.get(i) == Event.EventType.ANOMALOUS;
        if (labeledAnomalous) {
            totalPositive++;
            if (predictedAnomalous) {
                truePositive++;
            }
        } else if (predictedAnomalous) {
            falsePositive++;
        }
        i++;
    }

    // Guard the denominators so an empty result fails with a clear message instead of NaN.
    Assert.assertTrue("model flagged no anomalies; precision is undefined", truePositive + falsePositive > 0);
    Assert.assertTrue("ground truth contains no anomalies; recall is undefined", totalPositive > 0);

    float precision = (float) truePositive / (truePositive + falsePositive);
    float recall = (float) truePositive / totalPositive;
    Assert.assertEquals(0.7, precision, 0.01);
    Assert.assertEquals(1.0, recall, 0.01);
}
Usage of `org.opensearch.ml.common.parameter.MLPredictionOutput` in the ml-commons project (opensearch-project).
Class `FixedInTimeRandomCutForestTest`, method `predict`.
/**
 * Trains the fixed-in-time random cut forest, predicts on the prediction frame, and checks two
 * things:
 * <ul>
 *   <li>there is one prediction row per input row;</li>
 *   <li>more than one of the rows at indices 0, 100, 200, ... has a column-1 value above 0.01.</li>
 * </ul>
 * NOTE(review): those indices are presumably where the fixture injects anomalies, and column 1
 * is presumably the anomaly grade; confirm against the data setup.
 */
@Test
public void predict() {
    Model trainedForest = forest.train(trainDataFrame);
    DataFrame scored = ((MLPredictionOutput) forest.predict(predictionDataFrame, trainedForest)).getPredictionResult();
    Assert.assertEquals(dataSize, scored.size());

    int anomalyCount = 0;
    // Visit only every 100th row; these are the same indices the i % 100 == 0 filter selected.
    for (int idx = 0; idx < dataSize; idx += 100) {
        double grade = scored.getRow(idx).getValue(1).doubleValue();
        if (grade > 0.01) {
            anomalyCount++;
        }
    }

    // total anomalies 5 (per original note; depends on dataSize)
    Assert.assertTrue("Fewer anomaly detected: " + anomalyCount, anomalyCount > 1);
}
Aggregations