Search in sources :

Example 1 with DensityOutput

use of com.amazon.randomcutforest.returntypes.DensityOutput in project random-cut-forest-by-aws by aws.

The class DynamicDensity, method run.

/**
 * Streams a rotating point cloud into a forest and writes density diagnostics to the file
 * {@code dynamic_density_example}, one rotation step (2 degrees) at a time.
 *
 * <p>Each rotation step appends two gnuplot data blocks, each terminated by two blank lines:
 * an even-indexed block of {@code x y density} rows for the rotated data points, and an
 * odd-indexed block of {@code x y dx dy} rows giving, on a regular grid, a short arrow toward
 * higher density. With gnuplot, the arrows can be animated via
 *
 * <pre>
 * do for [i=0:358:2] {plot "dynamic_density_example" index (i+1) u 1:2:3:4 w vectors t ""}
 * </pre>
 *
 * and the raw density at the points via
 *
 * <pre>
 * do for [i=0:358:2] {plot "dynamic_density_example" index i w p pt 7 palette t ""}
 * </pre>
 *
 * @throws Exception if the output file cannot be created or written, or if a helper used here
 *     throws
 */
@Override
public void run() throws Exception {
    int newDimensions = 2;
    long randomSeed = 123;
    RandomCutForest newForest = RandomCutForest.builder().numberOfTrees(100).sampleSize(256)
            .dimensions(newDimensions).randomSeed(randomSeed).timeDecay(1.0 / 800)
            .centerOfMassEnabled(true).build();
    String name = "dynamic_density_example";
    double[][] data = generate(1000);
    // try-with-resources closes (and therefore flushes) the writer even if a forest call or a
    // write throws part-way through; previously the handle leaked on any exception
    try (BufferedWriter file = new BufferedWriter(new FileWriter(name))) {
        for (int degree = 0; degree < 360; degree += 2) {
            double angle = -2 * PI * degree / 360;
            for (double[] datum : data) {
                newForest.update(rotateClockWise(datum, angle));
            }
            for (double[] datum : data) {
                double[] queryPoint = rotateClockWise(datum, angle);
                DensityOutput density = newForest.getSimpleDensity(queryPoint);
                double value = density.getDensity(0.001, 2);
                file.append(queryPoint[0] + " " + queryPoint[1] + " " + value + "\n");
            }
            // two blank lines end a gnuplot data block (even index: point densities)
            file.append("\n");
            file.append("\n");
            for (double x = -0.95; x < 1; x += 0.1) {
                for (double y = -0.95; y < 1; y += 0.1) {
                    DensityOutput density = newForest.getSimpleDensity(new double[] { x, y });
                    // fetch each side of the directional density once instead of once per component
                    double[] low = density.getDirectionalDensity(0.001, 2).low;
                    double[] high = density.getDirectionalDensity(0.001, 2).high;
                    double aboveInY = low[1];
                    double belowInY = high[1];
                    double toTheLeft = high[0];
                    double toTheRight = low[0];
                    double len = Math.sqrt(aboveInY * aboveInY + belowInY * belowInY
                            + toTheLeft * toTheLeft + toTheRight * toTheRight);
                    // with zero directional density the normalisation would be 0/0 = NaN;
                    // emit a zero-length arrow instead
                    double dx = 0;
                    double dy = 0;
                    if (len > 0) {
                        dx = (toTheRight - toTheLeft) * 0.05 / len;
                        dy = (aboveInY - belowInY) * 0.05 / len;
                    }
                    file.append(x + " " + y + " " + dx + " " + dy + "\n");
                }
            }
            // end of the odd-indexed (arrow) data block
            file.append("\n");
            file.append("\n");
        }
    }
}
Also used : DensityOutput(com.amazon.randomcutforest.returntypes.DensityOutput) RandomCutForest(com.amazon.randomcutforest.RandomCutForest) FileWriter(java.io.FileWriter) BufferedWriter(java.io.BufferedWriter)

Example 2 with DensityOutput

use of com.amazon.randomcutforest.returntypes.DensityOutput in project random-cut-forest-by-aws by aws.

The class DynamicPointSetFunctionalTest, method movingDensity.

@Test
public void movingDensity() {
    int newDimensions = 2;
    randomSeed = 123;
    RandomCutForest newForest = RandomCutForest.builder().dimensions(newDimensions).randomSeed(randomSeed)
            .timeDecay(1.0 / 800).centerOfMassEnabled(true).storeSequenceIndexesEnabled(true).build();
    double[][] data = generateFan(1000, 3);
    double[] queryPoint = new double[] { 0.7, 0 };
    for (int degree = 0; degree < 360; degree += 2) {
        for (double[] datum : data) {
            newForest.update(rotateClockWise(datum, 2 * PI * degree / 360));
        }
        DensityOutput density = newForest.getSimpleDensity(queryPoint);
        double value = density.getDensity(0.001, 2);

        // All windows below repeat every 120 degrees (see inPeriodicWindow).
        // Away from the fan, the density at the query point stays low.
        if (inPeriodicWindow(degree, 0, 60)) {
            assertTrue(value < 0.8);
        }
        // The fan is above / close by at 90, 210 and 330 degrees.
        // 0.5 is intentionally below 0.8 for a robust test.
        if (inPeriodicWindow(degree, 75, 105)) {
            assertTrue(value > 0.5);
        }

        // Testing for directionality. There can be unclear directionality when the blades are
        // right above, so the checks below allow a freedom of 10% of the total value.
        double[] low = density.getDirectionalDensity(0.001, 2).low;
        double[] high = density.getDirectionalDensity(0.001, 2).high;
        double bladeAboveInY = low[1];
        double bladeBelowInY = high[1];
        double bladesToTheLeft = high[0];
        double bladesToTheRight = low[0];
        // the four directional components are expected to sum to the total density
        assertEquals(value, bladeAboveInY + bladeBelowInY + bladesToTheLeft + bladesToTheRight, 1E-6);

        if (inPeriodicWindow(degree, 75, 85)) {
            assertTrue(bladeAboveInY + 0.1 * value > bladeBelowInY);
            assertTrue(bladeAboveInY + 0.1 * value > bladesToTheRight);
        }
        if (inPeriodicWindow(degree, 95, 105)) {
            assertTrue(bladeBelowInY + 0.1 * value > bladeAboveInY);
            assertTrue(bladeBelowInY + 0.1 * value > bladesToTheRight);
        }
        if (inPeriodicWindow(degree, 60, 75)) {
            assertTrue(bladeAboveInY + 0.1 * value > bladesToTheLeft);
            assertTrue(bladeAboveInY + 0.1 * value > bladesToTheRight);
        }
        if (inPeriodicWindow(degree, 105, 120)) {
            assertTrue(bladeBelowInY + 0.1 * value > bladesToTheLeft);
            assertTrue(bladeBelowInY + 0.1 * value > bladesToTheRight);
        }
        // fans are farthest to the left at 30, 150 and 270 degrees
        if (inPeriodicWindow(degree, 15, 45)) {
            assertTrue(bladesToTheLeft + 0.1 * value > bladeAboveInY + bladeBelowInY + bladesToTheRight);
            assertTrue(bladeAboveInY + bladeBelowInY + 0.1 * value > bladesToTheRight);
        }
    }
}

/**
 * Returns whether {@code degree} lies in the closed interval {@code [low, high]} shifted by 0, 120
 * or 240 degrees. This matches the 120-degree periodic windows checked in {@code movingDensity},
 * whose loop only produces degrees in {@code [0, 358]}.
 */
private static boolean inPeriodicWindow(int degree, int low, int high) {
    for (int shift = 0; shift < 360; shift += 120) {
        if (degree >= low + shift && degree <= high + shift) {
            return true;
        }
    }
    return false;
}
Also used : DensityOutput(com.amazon.randomcutforest.returntypes.DensityOutput) Test(org.junit.jupiter.api.Test)

Example 3 with DensityOutput

use of com.amazon.randomcutforest.returntypes.DensityOutput in project random-cut-forest-by-aws by aws.

The class RandomCutForestTest, method testGetSimpleDensity.

@Test
public void testGetSimpleDensity() {
    float[] point = { 12.3f, -45.6f };

    // Phase 1: samplers not yet full, so the result must match an empty DensityOutput.
    DensityOutput emptyOutput = new DensityOutput(dimensions, sampleSize);
    assertFalse(forest.samplersFull());
    DensityOutput actual = forest.getSimpleDensity(point);
    assertEquals(emptyOutput.getDensity(), actual.getDensity(), EPSILON);

    // Phase 2: report full samplers and stub every tree to return a random measure.
    doReturn(true).when(forest).samplersFull();
    List<InterpolationMeasure> perTreeMeasures = new ArrayList<>();
    for (int treeIndex = 0; treeIndex < numberOfTrees; treeIndex++) {
        InterpolationMeasure stubbed = new InterpolationMeasure(dimensions, sampleSize);
        for (int dim = 0; dim < dimensions; dim++) {
            stubbed.measure.high[dim] = Math.random();
            stubbed.measure.low[dim] = Math.random();
            stubbed.distances.high[dim] = Math.random();
            stubbed.distances.low[dim] = Math.random();
            stubbed.probMass.high[dim] = Math.random();
            stubbed.probMass.low[dim] = Math.random();
        }
        ITree<Integer, float[]> tree = ((SamplerPlusTree<Integer, float[]>) components.get(treeIndex)).getTree();
        when(tree.traverse(aryEq(point), any(VisitorFactory.class))).thenReturn(stubbed);
        perTreeMeasures.add(stubbed);
    }

    // The forest's answer must equal the stubbed measures folded with InterpolationMeasure's collector.
    Collector<InterpolationMeasure, ?, InterpolationMeasure> combiner = InterpolationMeasure.collector(dimensions,
            sampleSize, numberOfTrees);
    DensityOutput expected = new DensityOutput(perTreeMeasures.stream().collect(combiner));
    actual = forest.getSimpleDensity(point);
    assertEquals(expected.getDensity(), actual.getDensity(), EPSILON);
}
Also used : DensityOutput(com.amazon.randomcutforest.returntypes.DensityOutput) ArrayList(java.util.ArrayList) InterpolationMeasure(com.amazon.randomcutforest.returntypes.InterpolationMeasure) SamplerPlusTree(com.amazon.randomcutforest.executor.SamplerPlusTree) Test(org.junit.jupiter.api.Test)

Example 4 with DensityOutput

use of com.amazon.randomcutforest.returntypes.DensityOutput in project random-cut-forest-by-aws by aws.

The class RandomCutForestFunctionalTest, method testSimpleDensityWhenSamplerNotFullThenDensityIsZero.

@ParameterizedTest
@ArgumentsSource(TestForestProvider.class)
public void testSimpleDensityWhenSamplerNotFullThenDensityIsZero(RandomCutForest forest) {
    // Stub samplersFull() on a spy so the "not full" branch is taken for every provided forest.
    RandomCutForest spiedForest = spy(forest);
    when(spiedForest.samplersFull()).thenReturn(false);
    double[] origin = { 0.0, 0.0, 0.0 };
    double density = spiedForest.getSimpleDensity(origin).getDensity(0.001, 3);
    assertEquals(0.0, density);
}
Also used : DensityOutput(com.amazon.randomcutforest.returntypes.DensityOutput) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest) ArgumentsSource(org.junit.jupiter.params.provider.ArgumentsSource)

Example 5 with DensityOutput

use of com.amazon.randomcutforest.returntypes.DensityOutput in project random-cut-forest-by-aws by aws.

The class RandomCutForest, method getSimpleDensity.

/**
 * Computes a simple density estimate for {@code point}.
 *
 * <p>While the samplers are not yet full, an empty {@link DensityOutput} built from
 * {@code dimensions} and {@code sampleSize} is returned. Otherwise the point, after
 * {@code transformToShingledPoint}, is traversed through the forest:
 * <ul>
 * <li>each tree is visited by a {@link SimpleInterpolationVisitor} on the tree-projected point;</li>
 * <li>each tree's {@link InterpolationMeasure} is lifted back with {@code liftFromTree};</li>
 * <li>the per-tree measures are combined with {@code InterpolationMeasure.collector}.</li>
 * </ul>
 *
 * @param point the query point
 * @return the density output for {@code point}
 */
public DensityOutput getSimpleDensity(float[] point) {
    if (samplersFull()) {
        IVisitorFactory<InterpolationMeasure> factory = new VisitorFactory<>(
                (t, query) -> new SimpleInterpolationVisitor(t.projectToTree(query), sampleSize, 1.0,
                        centerOfMassEnabled),
                (t, measure) -> measure.lift(t::liftFromTree));
        Collector<InterpolationMeasure, ?, InterpolationMeasure> combiner = InterpolationMeasure
                .collector(dimensions, sampleSize, numberOfTrees);
        return new DensityOutput(traverseForest(transformToShingledPoint(point), factory, combiner));
    }
    return new DensityOutput(dimensions, sampleSize);
}
Also used : DensityOutput(com.amazon.randomcutforest.returntypes.DensityOutput) SimpleInterpolationVisitor(com.amazon.randomcutforest.interpolation.SimpleInterpolationVisitor) InterpolationMeasure(com.amazon.randomcutforest.returntypes.InterpolationMeasure)

Aggregations

DensityOutput (com.amazon.randomcutforest.returntypes.DensityOutput): 8
InterpolationMeasure (com.amazon.randomcutforest.returntypes.InterpolationMeasure): 2
Test (org.junit.jupiter.api.Test): 2
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 2
ArgumentsSource (org.junit.jupiter.params.provider.ArgumentsSource): 2
Benchmark (org.openjdk.jmh.annotations.Benchmark): 2
OperationsPerInvocation (org.openjdk.jmh.annotations.OperationsPerInvocation): 2
RandomCutForest (com.amazon.randomcutforest.RandomCutForest): 1
SamplerPlusTree (com.amazon.randomcutforest.executor.SamplerPlusTree): 1
SimpleInterpolationVisitor (com.amazon.randomcutforest.interpolation.SimpleInterpolationVisitor): 1
BufferedWriter (java.io.BufferedWriter): 1
FileWriter (java.io.FileWriter): 1
ArrayList (java.util.ArrayList): 1