use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
the class AnomalyDetectionLibSVMTest method predict.
/**
 * Trains the detector on the training frame, predicts on the prediction frame, and
 * checks precision and recall of rows whose column-1 string value is "ANOMALOUS"
 * against the expected labels in {@code predictionLabels}.
 */
@Test
public void predict() {
    Model trained = anomalyDetection.train(trainDataFrame);
    MLPredictionOutput predictionOutput = (MLPredictionOutput) anomalyDetection.predict(predictionDataFrame, trained);
    DataFrame predictions = predictionOutput.getPredictionResult();

    int hits = 0;
    int falseAlarms = 0;
    int expectedAnomalies = 0;
    int index = 0;
    for (Row row : predictions) {
        // Read the predicted label first, then the expected one, for the same row position.
        boolean flagged = "ANOMALOUS".equals(row.getValue(1).stringValue());
        boolean expected = predictionLabels.get(index) == Event.EventType.ANOMALOUS;
        index++;

        if (expected) {
            expectedAnomalies++;
        }
        if (flagged) {
            if (expected) {
                hits++;
            } else {
                falseAlarms++;
            }
        }
    }

    // Float division on purpose: a zero denominator yields NaN instead of throwing.
    float precision = hits / (float) (hits + falseAlarms);
    float recall = hits / (float) expectedAnomalies;
    Assert.assertEquals(0.7, precision, 0.01);
    Assert.assertEquals(1.0, recall, 0.01);
}
use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
the class FixedInTimeRandomCutForestTest method predict.
/**
 * Trains the forest, predicts on the prediction frame, and asserts that more than one
 * of the rows at indices divisible by 100 carries a column-1 value above 0.01.
 */
@Test
public void predict() {
    Model trained = forest.train(trainDataFrame);
    MLPredictionOutput predictionOutput = (MLPredictionOutput) forest.predict(predictionDataFrame, trained);
    DataFrame predictions = predictionOutput.getPredictionResult();
    Assert.assertEquals(dataSize, predictions.size());

    int flagged = 0;
    // Only every 100th row (0, 100, 200, ...) is inspected.
    for (int rowIndex = 0; rowIndex < dataSize; rowIndex += 100) {
        double value = predictions.getRow(rowIndex).getValue(1).doubleValue();
        if (value > 0.01) {
            flagged++;
        }
    }
    // total anomalies 5
    Assert.assertTrue("Fewer anomaly detected: " + flagged, flagged > 1);
}
use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
the class BatchRandomCutForest method process.
/**
 * Scores every row of {@code dataFrame} with {@code forest} and returns one result map per row.
 *
 * <p>Each row is turned into a point holding one {@code double} per column. The anomaly score
 * is computed before the point is (optionally) fed back into the forest, so a row never
 * influences its own score.
 *
 * @param dataFrame rows to score; every column value is read through {@code doubleValue()}
 * @param forest forest used for scoring; it is updated in place with the training rows
 * @param actualTrainingDataSize number of leading rows that also update the forest;
 *        {@code null} means every row updates the forest
 * @return one map per row, in row order, with keys {@code "score"} (the anomaly score) and
 *         {@code "anomalous"} (whether the score exceeds {@code anomalyScoreThreshold})
 */
private List<Map<String, Object>> process(DataFrame dataFrame, RandomCutForest forest, Integer actualTrainingDataSize) {
    ColumnMeta[] columnMetas = dataFrame.columnMetas();
    int dimensions = columnMetas.length;
    List<Map<String, Object>> predictResult = new ArrayList<>();
    for (int rowNum = 0; rowNum < dataFrame.size(); rowNum++) {
        // Fetch the row once per row instead of once per column.
        Row row = dataFrame.getRow(rowNum);
        // A fresh array per row: the point is handed to the forest, so it must not be reused.
        double[] point = new double[dimensions];
        for (int i = 0; i < dimensions; i++) {
            point[i] = row.getValue(i).doubleValue();
        }
        // Score first, then update, so the point does not affect its own score.
        double anomalyScore = forest.getAnomalyScore(point);
        if (actualTrainingDataSize == null || rowNum < actualTrainingDataSize) {
            forest.update(point);
        }
        Map<String, Object> result = new HashMap<>();
        result.put("score", anomalyScore);
        result.put("anomalous", anomalyScore > anomalyScoreThreshold);
        predictResult.add(result);
    }
    return predictResult;
}
use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
the class BatchRandomCutForestTest method constructRCFDataFrame.
/**
 * Builds a frame with a single INTEGER column named "value" and {@code dataSize} random rows.
 *
 * @param predict when {@code true}, every row whose index is divisible by 100 gets a value
 *        drawn from [100, 1000); all other rows (and all rows when {@code false}) get a
 *        value drawn from [1, 10)
 * @return the populated frame
 */
private DataFrame constructRCFDataFrame(boolean predict) {
    DataFrame frame = new DefaultDataFrame(new ColumnMeta[] { new ColumnMeta("value", ColumnType.INTEGER) });
    for (int rowIndex = 0; rowIndex < dataSize; rowIndex++) {
        boolean spike = predict && rowIndex % 100 == 0;
        int origin = spike ? 100 : 1;
        int bound = spike ? 1000 : 10;
        frame.appendRow(new Object[] { ThreadLocalRandom.current().nextInt(origin, bound) });
    }
    return frame;
}
use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
the class BatchRandomCutForestTest method verifyPredictionResult.
/**
 * Asserts that the prediction frame has {@code dataSize} rows and that more than one of
 * the rows at indices divisible by 100 carries a column-0 value above 0.01.
 *
 * @param output prediction output whose result frame is verified
 */
private void verifyPredictionResult(MLPredictionOutput output) {
    DataFrame result = output.getPredictionResult();
    Assert.assertEquals(dataSize, result.size());

    int flagged = 0;
    for (int index = 0; index < dataSize; index++) {
        Row current = result.getRow(index);
        // Only every 100th row is inspected; short-circuit keeps other rows' values unread.
        boolean sampled = index % 100 == 0;
        if (sampled && current.getValue(0).doubleValue() > 0.01) {
            flagged++;
        }
    }
    // total anomalies 5
    Assert.assertTrue("Fewer anomaly detected: " + flagged, flagged > 1);
}
Aggregations