Search in sources :

Example 6 with RandomCutForestState

use of com.amazon.randomcutforest.state.RandomCutForestState in project random-cut-forest-by-aws by aws.

Method convert of class V1JsonToV3StateConverter.

/**
 * Merges a collection of RCF-1.0 models that share the same model parameters
 * into one RCF-3.0 state with a fixed number of trees. The new number of trees
 * has to be less than or equal to the total number of trees in the old models;
 * samplers are taken in order, starting with the trees of the first forest.
 * Model parameters and the execution context are copied from the first forest
 * and can be adjusted subsequently via setters. Preconditions are enforced with
 * checkArgument.
 *
 * @param serializedForests A non-empty list of forests that share the same
 *                          dimensions and sample size, together having at
 *                          least numberOfTrees trees
 * @param numberOfTrees     the new number of trees; must be positive
 * @param precision         the precision of the new forest
 * @return a merged RCF state built from the samplers of the first
 *         numberOfTrees trees
 */
public RandomCutForestState convert(List<V1SerializedRandomCutForest> serializedForests, int numberOfTrees, Precision precision) {
    checkArgument(serializedForests.size() > 0, "incorrect usage of convert");
    checkArgument(numberOfTrees > 0, "incorrect parameter");
    V1SerializedRandomCutForest first = serializedForests.get(0);
    int dimensions = first.getDimensions();
    int sampleSize = first.getSampleSize();

    // Every sampler is loaded into a single point store whose dimensions and
    // capacity are derived from the first forest (see SamplerConverter below), so
    // all forests must agree on these parameters for the merged state to be valid.
    int sum = 0;
    for (V1SerializedRandomCutForest forest : serializedForests) {
        checkArgument(forest.getDimensions() == dimensions, "all forests must have the same dimensions");
        checkArgument(forest.getSampleSize() == sampleSize, "all forests must have the same sample size");
        sum += forest.getNumberOfTrees();
    }
    checkArgument(sum >= numberOfTrees, "incorrect parameters");

    RandomCutForestState state = new RandomCutForestState();
    state.setNumberOfTrees(numberOfTrees);
    state.setDimensions(dimensions);
    state.setTimeDecay(first.getLambda());
    state.setSampleSize(sampleSize);
    state.setShingleSize(1);
    state.setCenterOfMassEnabled(first.isCenterOfMassEnabled());
    state.setOutputAfter(first.getOutputAfter());
    state.setStoreSequenceIndexesEnabled(first.isStoreSequenceIndexesEnabled());
    // The update count is taken from the first forest only.
    state.setTotalUpdates(first.getExecutor().getExecutor().getTotalUpdates());
    state.setCompact(true);
    state.setInternalShinglingEnabled(false);
    state.setBoundingBoxCacheFraction(1.0);
    state.setSaveSamplerStateEnabled(true);
    state.setSaveTreeStateEnabled(false);
    state.setSaveCoordinatorStateEnabled(true);
    state.setPrecision(precision.name());
    state.setCompressed(false);
    state.setPartialTreeState(false);

    ExecutionContext executionContext = new ExecutionContext();
    executionContext.setParallelExecutionEnabled(first.isParallelExecutionEnabled());
    executionContext.setThreadPoolSize(first.getThreadPoolSize());
    state.setExecutionContext(executionContext);

    // One extra slot beyond numberOfTrees * sampleSize, matching the original sizing.
    SamplerConverter samplerConverter = new SamplerConverter(state.getDimensions(), state.getNumberOfTrees() * state.getSampleSize() + 1, precision, numberOfTrees);
    // Take the samplers of the first numberOfTrees trees across all forests, in order.
    serializedForests.stream().flatMap(f -> Arrays.stream(f.getExecutor().getExecutor().getTreeUpdaters())).limit(numberOfTrees).map(V1SerializedRandomCutForest.TreeUpdater::getSampler).forEach(samplerConverter::addSampler);
    state.setPointStoreState(samplerConverter.getPointStoreState(precision));
    state.setCompactSamplerStates(samplerConverter.compactSamplerStates);
    return state;
}
Also used: Arrays(java.util.Arrays) RandomCutForestState(com.amazon.randomcutforest.state.RandomCutForestState) ExecutionContext(com.amazon.randomcutforest.state.ExecutionContext) CompactSamplerState(com.amazon.randomcutforest.state.sampler.CompactSamplerState) URL(java.net.URL) Precision(com.amazon.randomcutforest.config.Precision) CommonUtils.checkArgument(com.amazon.randomcutforest.CommonUtils.checkArgument) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) IOException(java.io.IOException) Reader(java.io.Reader) PointStoreState(com.amazon.randomcutforest.state.store.PointStoreState) RandomCutTree(com.amazon.randomcutforest.tree.RandomCutTree) ArrayList(java.util.ArrayList) PointStore(com.amazon.randomcutforest.store.PointStore) ITree(com.amazon.randomcutforest.tree.ITree) List(java.util.List) Optional(java.util.Optional) PointStoreMapper(com.amazon.randomcutforest.state.store.PointStoreMapper) IPointStore(com.amazon.randomcutforest.store.IPointStore) Collections(java.util.Collections)

Example 7 with RandomCutForestState

use of com.amazon.randomcutforest.state.RandomCutForestState in project random-cut-forest-by-aws by aws.

Method testConvert of class V1JsonToV3StateConverterTest.

@ParameterizedTest
@MethodSource("args")
public void testConvert(V1JsonResource jsonResource, Precision precision) {
    String resource = jsonResource.getResource();
    // getResourceAsStream returns null for a missing resource; fail clearly instead
    // of letting InputStreamReader throw an unexplained NullPointerException.
    InputStream resourceStream = V1JsonToV3StateConverterTest.class.getResourceAsStream(resource);
    if (resourceStream == null) {
        fail("JSON resource not found: " + resource);
    }
    try (InputStream is = resourceStream;
        BufferedReader rr = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
        StringBuilder b = new StringBuilder();
        String line;
        while ((line = rr.readLine()) != null) {
            b.append(line);
        }
        String json = b.toString();
        RandomCutForestState state = converter.convert(json, precision);
        assertEquals(jsonResource.getDimensions(), state.getDimensions());
        assertEquals(jsonResource.getNumberOfTrees(), state.getNumberOfTrees());
        assertEquals(jsonResource.getSampleSize(), state.getSampleSize());
        RandomCutForest forest = new RandomCutForestMapper().toModel(state, 0);
        assertEquals(jsonResource.getDimensions(), forest.getDimensions());
        assertEquals(jsonResource.getNumberOfTrees(), forest.getNumberOfTrees());
        assertEquals(jsonResource.getSampleSize(), forest.getSampleSize());
        // perform a simple validation of the deserialized forest by update and scoring
        // with a few points
        Random random = new Random(0);
        for (int i = 0; i < 100; i++) {
            double[] point = getPoint(jsonResource.getDimensions(), random);
            double score = forest.getAnomalyScore(point);
            assertTrue(score > 0);
            forest.update(point);
        }
        String newString = new ObjectMapper().writeValueAsString(new RandomCutForestMapper().toState(forest));
        // Floating-point division: integer division would truncate the reported ratio.
        double improvementFactor = (double) json.length() / newString.length();
        System.out.println(" Old size " + json.length() + ", new Size " + newString.length() + ", improvement factor " + improvementFactor);
    } catch (IOException e) {
        fail("Unable to load JSON resource", e);
    }
}
Also used : InputStreamReader(java.io.InputStreamReader) InputStream(java.io.InputStream) RandomCutForest(com.amazon.randomcutforest.RandomCutForest) RandomCutForestState(com.amazon.randomcutforest.state.RandomCutForestState) IOException(java.io.IOException) Random(java.util.Random) RandomCutForestMapper(com.amazon.randomcutforest.state.RandomCutForestMapper) BufferedReader(java.io.BufferedReader) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest) MethodSource(org.junit.jupiter.params.provider.MethodSource)

Example 8 with RandomCutForestState

use of com.amazon.randomcutforest.state.RandomCutForestState in project random-cut-forest-by-aws by aws.

Method testMerge of class V1JsonToV3StateConverterTest.

@ParameterizedTest
@MethodSource("args")
public void testMerge(V1JsonResource jsonResource, Precision precision) {
    String resource = jsonResource.getResource();
    // getResourceAsStream returns null for a missing resource; fail clearly instead
    // of letting InputStreamReader throw an unexplained NullPointerException.
    InputStream resourceStream = V1JsonToV3StateConverterTest.class.getResourceAsStream(resource);
    if (resourceStream == null) {
        fail("JSON resource not found: " + resource);
    }
    try (InputStream is = resourceStream;
        BufferedReader rr = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
        StringBuilder b = new StringBuilder();
        String line;
        while ((line = rr.readLine()) != null) {
            b.append(line);
        }
        String json = b.toString();
        Random sizeRandom = new Random();
        int number = sizeRandom.nextInt(10) + 1;
        // Random.nextInt requires a positive bound; guard the degenerate case where
        // only a single tree is available in total.
        int bound = Math.max(1, number * jsonResource.getNumberOfTrees() - 1);
        int testNumberOfTrees = Math.min(100, 1 + sizeRandom.nextInt(bound));
        ArrayList<String> models = new ArrayList<>();
        for (int i = 0; i < number; i++) {
            models.add(json);
        }
        RandomCutForestState state = converter.convert(models, testNumberOfTrees, precision)
                .orElseThrow(() -> new AssertionError("merge of " + number + " copies produced no state"));
        assertEquals(jsonResource.getDimensions(), state.getDimensions());
        assertEquals(testNumberOfTrees, state.getNumberOfTrees());
        assertEquals(jsonResource.getSampleSize(), state.getSampleSize());
        RandomCutForest forest = new RandomCutForestMapper().toModel(state, 0);
        assertEquals(jsonResource.getDimensions(), forest.getDimensions());
        assertEquals(testNumberOfTrees, forest.getNumberOfTrees());
        assertEquals(jsonResource.getSampleSize(), forest.getSampleSize());
        // perform a simple validation of the deserialized forest by update and scoring
        // with a few points
        Random random = new Random(0);
        for (int i = 0; i < 100; i++) {
            double[] point = getPoint(jsonResource.getDimensions(), random);
            double score = forest.getAnomalyScore(point);
            assertTrue(score > 0);
            forest.update(point);
        }
        int expectedSize = (int) Math.floor(1.0 * testNumberOfTrees * json.length() / (number * jsonResource.getNumberOfTrees()));
        String newString = new ObjectMapper().writeValueAsString(new RandomCutForestMapper().toState(forest));
        System.out.println(" Copied " + number + " times, old number of trees " + jsonResource.getNumberOfTrees() + ", new trees " + testNumberOfTrees + ", Expected Old size " + expectedSize + ", new Size " + newString.length());
    } catch (IOException e) {
        fail("Unable to load JSON resource", e);
    }
}
Also used : InputStreamReader(java.io.InputStreamReader) InputStream(java.io.InputStream) RandomCutForest(com.amazon.randomcutforest.RandomCutForest) ArrayList(java.util.ArrayList) RandomCutForestState(com.amazon.randomcutforest.state.RandomCutForestState) IOException(java.io.IOException) Random(java.util.Random) RandomCutForestMapper(com.amazon.randomcutforest.state.RandomCutForestMapper) BufferedReader(java.io.BufferedReader) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest) MethodSource(org.junit.jupiter.params.provider.MethodSource)

Example 9 with RandomCutForestState

use of com.amazon.randomcutforest.state.RandomCutForestState in project random-cut-forest-by-aws by aws.

Method testUpdate of class RandomCutForestShingledFunctionalTest.

@Test
public void testUpdate() {
    int dimensions = 10;
    RandomCutForest forest = RandomCutForest.builder().numberOfTrees(100).compact(true).dimensions(dimensions).randomSeed(0).sampleSize(200).precision(Precision.FLOAT_32).build();
    double[][] trainingData = genShingledData(1000, dimensions, 0);
    double[][] testData = genShingledData(100, dimensions, 1);

    // Train first: the training data was previously generated but never used, so
    // the state round trips below ran against a nearly empty forest.
    for (double[] point : trainingData) {
        forest.update(point);
    }

    // The mapper configuration is identical for every round trip, so build it once.
    RandomCutForestMapper mapper = new RandomCutForestMapper();
    mapper.setSaveExecutorContextEnabled(true);
    mapper.setSaveTreeStateEnabled(true);

    // Score and update each test point, then round-trip the forest through its
    // state (including tree state) before handling the next point.
    for (double[] point : testData) {
        forest.getAnomalyScore(point);
        forest.update(point);
        RandomCutForestState forestState = mapper.toState(forest);
        forest = mapper.toModel(forestState);
    }
}
Also used : RandomCutForestMapper(com.amazon.randomcutforest.state.RandomCutForestMapper) RandomCutForestState(com.amazon.randomcutforest.state.RandomCutForestState) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 10 with RandomCutForestState

use of com.amazon.randomcutforest.state.RandomCutForestState in project random-cut-forest-by-aws by aws.

Method roundTripFromProtostuff of class StateMapperShingledBenchmark.

@Benchmark
@OperationsPerInvocation(NUM_TEST_SAMPLES)
public byte[] roundTripFromProtostuff(BenchmarkState state, Blackhole blackhole) {
    // Start from the pre-serialized forest. Each iteration decodes the bytes produced
    // by the previous iteration, scores and updates one test point, and re-encodes.
    bytes = state.protostuff;
    double[][] samples = state.testData;
    int index = 0;
    while (index < NUM_TEST_SAMPLES) {
        // Decode the current bytes into a fresh state message.
        Schema<RandomCutForestState> stateSchema = RuntimeSchema.getSchema(RandomCutForestState.class);
        RandomCutForestState decoded = stateSchema.newMessage();
        ProtostuffIOUtil.mergeFrom(bytes, decoded, stateSchema);

        // Rebuild the model, tree state saving controlled by the benchmark parameter.
        RandomCutForestMapper forestMapper = new RandomCutForestMapper();
        forestMapper.setSaveExecutorContextEnabled(true);
        forestMapper.setSaveTreeStateEnabled(state.saveTreeState);
        RandomCutForest model = forestMapper.toModel(decoded);

        double[] sample = samples[index];
        double anomalyScore = model.getAnomalyScore(sample);
        blackhole.consume(anomalyScore);
        model.update(sample);

        // Encode the updated model back to bytes for the next iteration.
        RandomCutForestState encoded = forestMapper.toState(model);
        LinkedBuffer scratch = LinkedBuffer.allocate(512);
        try {
            bytes = ProtostuffIOUtil.toByteArray(encoded, stateSchema, scratch);
        } finally {
            scratch.clear();
        }
        index++;
    }
    return bytes;
}
Also used : LinkedBuffer(io.protostuff.LinkedBuffer) RandomCutForestMapper(com.amazon.randomcutforest.state.RandomCutForestMapper) RandomCutForestState(com.amazon.randomcutforest.state.RandomCutForestState) Benchmark(org.openjdk.jmh.annotations.Benchmark) OperationsPerInvocation(org.openjdk.jmh.annotations.OperationsPerInvocation)

Aggregations

RandomCutForestState (com.amazon.randomcutforest.state.RandomCutForestState)17 RandomCutForestMapper (com.amazon.randomcutforest.state.RandomCutForestMapper)14 RandomCutForest (com.amazon.randomcutforest.RandomCutForest)8 Benchmark (org.openjdk.jmh.annotations.Benchmark)6 OperationsPerInvocation (org.openjdk.jmh.annotations.OperationsPerInvocation)6 Precision (com.amazon.randomcutforest.config.Precision)5 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)5 LinkedBuffer (io.protostuff.LinkedBuffer)5 NormalMixtureTestData (com.amazon.randomcutforest.testutils.NormalMixtureTestData)3 IOException (java.io.IOException)3 Random (java.util.Random)3 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)3 BufferedReader (java.io.BufferedReader)2 InputStream (java.io.InputStream)2 InputStreamReader (java.io.InputStreamReader)2 ArrayList (java.util.ArrayList)2 Test (org.junit.jupiter.api.Test)2 MethodSource (org.junit.jupiter.params.provider.MethodSource)2 CommonUtils.checkArgument (com.amazon.randomcutforest.CommonUtils.checkArgument)1 CompactSampler (com.amazon.randomcutforest.sampler.CompactSampler)1