Use of `com.amazon.randomcutforest.RandomCutForest` in the ml-commons project by opensearch-project.
Class `BatchRandomCutForest`, method `train`:
/**
 * Builds a random cut forest for the given data and returns it packaged as a {@link Model}.
 *
 * <p>The forest is fed through {@code process}. When {@code trainingDataSize} is unset, every row
 * of {@code dataFrame} is eligible for updating the forest. The resulting forest state is then
 * serialized into the model content.
 *
 * @param dataFrame the rows used to build the forest
 * @return a model named after {@code FunctionName.BATCH_RCF}, version 1, holding the serialized forest state
 */
@Override
public Model train(DataFrame dataFrame) {
    RandomCutForest rcf = createRandomCutForest(dataFrame);
    // Fall back to the full frame when no explicit training size was configured.
    Integer sampleCount = (trainingDataSize != null) ? trainingDataSize : dataFrame.size();
    process(dataFrame, rcf, sampleCount);

    Model trained = new Model();
    trained.setName(FunctionName.BATCH_RCF.name());
    trained.setVersion(1);
    RandomCutForestState forestState = rcfMapper.toState(rcf);
    trained.setContent(ModelSerDeSer.serialize(forestState));
    return trained;
}
Use of `com.amazon.randomcutforest.RandomCutForest` in the ml-commons project by opensearch-project.
Class `BatchRandomCutForest`, method `trainAndPredict`:
/**
 * Builds a random cut forest over {@code dataFrame} and returns the per-row scores produced along the way.
 *
 * <p>When {@code trainingDataSize} is unset, every row of {@code dataFrame} is eligible for updating
 * the forest.
 *
 * @param dataFrame the rows to score (and, up to the training size, to learn from)
 * @return a prediction output whose result frame is loaded from the per-row maps returned by {@code process}
 */
@Override
public MLOutput trainAndPredict(DataFrame dataFrame) {
    RandomCutForest rcf = createRandomCutForest(dataFrame);
    // Fall back to the full frame when no explicit training size was configured.
    Integer sampleCount = (trainingDataSize != null) ? trainingDataSize : dataFrame.size();
    List<Map<String, Object>> scores = process(dataFrame, rcf, sampleCount);
    return MLPredictionOutput.builder()
            .predictionResult(DataFrameBuilder.load(scores))
            .build();
}
Use of `com.amazon.randomcutforest.RandomCutForest` in the ml-commons project by opensearch-project.
Class `BatchRandomCutForest`, method `process`:
/**
 * Scores every row of {@code dataFrame} against {@code forest}, updating the forest with the leading rows.
 *
 * <p>Each row is converted to a {@code double[]} point. The conversion uses {@code doubleValue()} on
 * the value at each column index, for as many columns as {@code dataFrame.columnMetas()} reports.
 * The point is scored first. It is then added to the forest if {@code actualTrainingDataSize} is
 * {@code null} or the row index is below it.
 *
 * @param dataFrame the rows to score
 * @param forest the forest used for scoring; it is mutated by {@code update} for training rows
 * @param actualTrainingDataSize number of leading rows used to update the forest; {@code null} means all rows
 * @return one map per row, in row order, with key {@code "score"} (the anomaly score) and key
 *     {@code "anomalous"} ({@code true} when the score exceeds {@code anomalyScoreThreshold})
 */
private List<Map<String, Object>> process(DataFrame dataFrame, RandomCutForest forest, Integer actualTrainingDataSize) {
    ColumnMeta[] columnMetas = dataFrame.columnMetas();
    int columnCount = columnMetas.length;
    int rowCount = dataFrame.size();
    List<Map<String, Object>> predictResult = new ArrayList<>(rowCount);
    for (int rowNum = 0; rowNum < rowCount; rowNum++) {
        // Fetch the row once per iteration rather than once per column.
        Row row = dataFrame.getRow(rowNum);
        double[] point = new double[columnCount];
        for (int i = 0; i < columnCount; i++) {
            point[i] = row.getValue(i).doubleValue();
        }
        // Score before updating, so a training row is scored by a forest that does not yet contain it.
        double anomalyScore = forest.getAnomalyScore(point);
        if (actualTrainingDataSize == null || rowNum < actualTrainingDataSize) {
            forest.update(point);
        }
        Map<String, Object> result = new HashMap<>();
        result.put("score", anomalyScore);
        result.put("anomalous", anomalyScore > anomalyScoreThreshold);
        predictResult.add(result);
    }
    return predictResult;
}
Aggregations