Example 31 with RandomCutForest

Use of com.amazon.randomcutforest.RandomCutForest in project ml-commons by opensearch-project.

In class BatchRandomCutForest, method train.

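@Override
public Model train(DataFrame dataFrame) {
    RandomCutForest forest = createRandomCutForest(dataFrame);
    // A null trainingDataSize means every row in the frame updates the forest.
    Integer actualTrainingDataSize = trainingDataSize == null ? dataFrame.size() : trainingDataSize;
    // The per-row scores are discarded here; only the forest's updated state is kept.
    process(dataFrame, forest, actualTrainingDataSize);
    Model model = new Model();
    model.setName(FunctionName.BATCH_RCF.name());
    model.setVersion(1);
    // Convert the forest to its serializable state and store it as the model content.
    RandomCutForestState state = rcfMapper.toState(forest);
    model.setContent(ModelSerDeSer.serialize(state));
    return model;
}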
Also used: RandomCutForest (com.amazon.randomcutforest.RandomCutForest), Model (org.opensearch.ml.common.parameter.Model), RandomCutForestState (com.amazon.randomcutforest.state.RandomCutForestState)
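`train` resolves the effective training size before handing the frame to `process`: a null `trainingDataSize` means every row updates the forest, otherwise only rows with an index below that size do. The plain-Java sketch below isolates just that rule so it can be checked without the randomcutforest dependency; `TrainingSizeRule`, `resolve`, and `updatesForest` are hypothetical names, not part of ml-commons.

```java
import java.util.Objects;

/** Hypothetical helper mirroring the training-size logic in BatchRandomCutForest. */
final class TrainingSizeRule {
    private TrainingSizeRule() {}

    /** A null configured size falls back to the full data frame size. */
    static int resolve(Integer configuredSize, int frameSize) {
        return Objects.requireNonNullElse(configuredSize, frameSize);
    }

    /** Only rows whose index is below the training size update the model. */
    static boolean updatesForest(int rowNum, Integer trainingSize) {
        return trainingSize == null || rowNum < trainingSize;
    }
}
```

Because both `train` and `trainAndPredict` resolve the size before calling `process`, the value they pass is never null, so the null check inside `process` only acts as a safeguard.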

Example 32 with RandomCutForest

Use of com.amazon.randomcutforest.RandomCutForest in project ml-commons by opensearch-project.

In class BatchRandomCutForest, method trainAndPredict.

@Override
public MLOutput trainAndPredict(DataFrame dataFrame) {
    RandomCutForest forest = createRandomCutForest(dataFrame);
    Integer actualTrainingDataSize = trainingDataSize == null ? dataFrame.size() : trainingDataSize;
    List<Map<String, Object>> predictResult = process(dataFrame, forest, actualTrainingDataSize);
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build();
}
Also used: RandomCutForest (com.amazon.randomcutforest.RandomCutForest), HashMap (java.util.HashMap), Map (java.util.Map)
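Each element of the list that `trainAndPredict` loads with `DataFrameBuilder.load` is a map holding a `score` (the raw anomaly score) and an `anomalous` flag. The sketch below rebuilds only that row shape in plain Java and assumes nothing about how `DataFrameBuilder` converts it; `PredictionRows`, `toResult`, and `toResults` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of the per-row results that trainAndPredict wraps in a DataFrame. */
final class PredictionRows {
    private PredictionRows() {}

    /** Builds one result row; the threshold comparison is strict, as in process(). */
    static Map<String, Object> toResult(double anomalyScore, double threshold) {
        Map<String, Object> result = new HashMap<>();
        result.put("score", anomalyScore);
        result.put("anomalous", anomalyScore > threshold);
        return result;
    }

    /** Converts scores into result rows, preserving input order. */
    static List<Map<String, Object>> toResults(double[] scores, double threshold) {
        List<Map<String, Object>> rows = new ArrayList<>(scores.length);
        for (double score : scores) {
            rows.add(toResult(score, threshold));
        }
        return rows;
    }
}
```

Because the comparison is strict (`>`), a score exactly equal to `anomalyScoreThreshold` is reported as not anomalous.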

Example 33 with RandomCutForest

Use of com.amazon.randomcutforest.RandomCutForest in project ml-commons by opensearch-project.

In class BatchRandomCutForest, method process.

private List<Map<String, Object>> process(DataFrame dataFrame, RandomCutForest forest, Integer actualTrainingDataSize) {
    ColumnMeta[] columnMetas = dataFrame.columnMetas();
    List<Map<String, Object>> predictResult = new ArrayList<>();
    for (int rowNum = 0; rowNum < dataFrame.size(); rowNum++) {
        // Fetch the row once, then copy each column value into the input point.
        Row row = dataFrame.getRow(rowNum);
        double[] point = new double[columnMetas.length];
        for (int i = 0; i < columnMetas.length; i++) {
            ColumnValue value = row.getValue(i);
            point[i] = value.doubleValue();
        }
        // Score before updating so the point does not influence its own score.
        double anomalyScore = forest.getAnomalyScore(point);
        // Only the first actualTrainingDataSize rows update the forest.
        if (actualTrainingDataSize == null || rowNum < actualTrainingDataSize) {
            forest.update(point);
        }
        Map<String, Object> result = new HashMap<>();
        result.put("score", anomalyScore);
        result.put("anomalous", anomalyScore > anomalyScoreThreshold);
        predictResult.add(result);
    }
    return predictResult;
}
Also used: MLOutput (org.opensearch.ml.common.parameter.MLOutput), RandomCutForestState (com.amazon.randomcutforest.state.RandomCutForestState), Row (org.opensearch.ml.common.dataframe.Row), ColumnValue (org.opensearch.ml.common.dataframe.ColumnValue), MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput), DataFrame (org.opensearch.ml.common.dataframe.DataFrame), Function (org.opensearch.ml.engine.annotation.Function), HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), RandomCutForest (com.amazon.randomcutforest.RandomCutForest), List (java.util.List), Model (org.opensearch.ml.common.parameter.Model), FunctionName (org.opensearch.ml.common.parameter.FunctionName), ModelSerDeSer (org.opensearch.ml.engine.utils.ModelSerDeSer), Map (java.util.Map), Log4j2 (lombok.extern.log4j.Log4j2), MLAlgoParams (org.opensearch.ml.common.parameter.MLAlgoParams), Optional (java.util.Optional), TrainAndPredictable (org.opensearch.ml.engine.TrainAndPredictable), DataFrameBuilder (org.opensearch.ml.common.dataframe.DataFrameBuilder), RandomCutForestMapper (com.amazon.randomcutforest.state.RandomCutForestMapper), ColumnMeta (org.opensearch.ml.common.dataframe.ColumnMeta), BatchRCFParams (org.opensearch.ml.common.parameter.BatchRCFParams)
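`process` uses a score-then-update loop: each point is scored with `getAnomalyScore` before `update` is called, so a point never contributes to its own score, and updates stop once `actualTrainingDataSize` rows have been seen. The sketch below reproduces that control flow without the randomcutforest dependency. `StreamingScorer`, `MeanDistanceScorer`, and `ScoreThenUpdate` are hypothetical; `MeanDistanceScorer` is a toy distance-from-running-mean scorer used purely as a stand-in and does not implement the Random Cut Forest algorithm.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical streaming-detector contract, modeled on getAnomalyScore/update. */
interface StreamingScorer {
    double score(double[] point);

    void update(double[] point);
}

/** Toy stand-in (NOT RCF): Euclidean distance from the running mean of updated points. */
final class MeanDistanceScorer implements StreamingScorer {
    private double[] sum;
    private int count;

    @Override
    public double score(double[] point) {
        if (count == 0) {
            return 0.0; // nothing learned yet, so no evidence of an anomaly
        }
        double squared = 0.0;
        for (int i = 0; i < point.length; i++) {
            double diff = point[i] - sum[i] / count;
            squared += diff * diff;
        }
        return Math.sqrt(squared);
    }

    @Override
    public void update(double[] point) {
        if (sum == null) {
            sum = new double[point.length];
        }
        for (int i = 0; i < point.length; i++) {
            sum[i] += point[i];
        }
        count++;
    }

    int updateCount() {
        return count;
    }
}

final class ScoreThenUpdate {
    private ScoreThenUpdate() {}

    /** Scores every row, but only the first trainingSize rows (all if null) update the scorer. */
    static List<Double> process(double[][] rows, StreamingScorer scorer, Integer trainingSize) {
        List<Double> scores = new ArrayList<>(rows.length);
        for (int rowNum = 0; rowNum < rows.length; rowNum++) {
            double[] point = rows[rowNum];
            // Score first so a point never influences its own score.
            scores.add(scorer.score(point));
            // Rows past the training size are scored against a frozen model.
            if (trainingSize == null || rowNum < trainingSize) {
                scorer.update(point);
            }
        }
        return scores;
    }
}
```

With five rows and a training size of 3, the last two rows are both scored against the model as it stood after the third row, which is the same batch behavior `process` gets from `actualTrainingDataSize`.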

Aggregations

RandomCutForest (com.amazon.randomcutforest.RandomCutForest): 33
Random (java.util.Random): 14
RandomCutForestMapper (com.amazon.randomcutforest.state.RandomCutForestMapper): 11
Precision (com.amazon.randomcutforest.config.Precision): 10
RandomCutForestState (com.amazon.randomcutforest.state.RandomCutForestState): 10
Test (org.junit.jupiter.api.Test): 10
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 10
NormalMixtureTestData (com.amazon.randomcutforest.testutils.NormalMixtureTestData): 7
ThresholdedRandomCutForest (com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest): 5
AnomalyDescriptor (com.amazon.randomcutforest.parkservices.AnomalyDescriptor): 4
CompactSampler (com.amazon.randomcutforest.sampler.CompactSampler): 4
MultiDimDataWithKey (com.amazon.randomcutforest.testutils.MultiDimDataWithKey): 4
ArrayList (java.util.ArrayList): 4
ComponentList (com.amazon.randomcutforest.ComponentList): 3
PointStoreCoordinator (com.amazon.randomcutforest.executor.PointStoreCoordinator): 3
CompactSamplerMapper (com.amazon.randomcutforest.state.sampler.CompactSamplerMapper): 3
CompactSamplerState (com.amazon.randomcutforest.state.sampler.CompactSamplerState): 3
PointStoreMapper (com.amazon.randomcutforest.state.store.PointStoreMapper): 3
CompactRandomCutTreeContext (com.amazon.randomcutforest.state.tree.CompactRandomCutTreeContext): 3
IPointStore (com.amazon.randomcutforest.store.IPointStore): 3