Search in sources:

Example 1 with InterpolationMeasure

Use of com.amazon.randomcutforest.returntypes.InterpolationMeasure in the project random-cut-forest-by-aws by aws.

In class SimpleInterpolationVisitorTest, method testAccept:

/**
 * Verifies how {@link SimpleInterpolationVisitor} accumulates its {@link InterpolationMeasure}
 * over two traversal steps: first a leaf whose point differs from the query point, then a parent
 * whose bounding box does not contain the query point.
 *
 * <p>Expected values are kept in vectors of length {@code 2 * pointToScore.length}. Index
 * {@code 2 * i} is compared against the {@code high} arrays of the result and index
 * {@code 2 * i + 1} against the {@code low} arrays. All comparisons use exact double equality
 * (no tolerance).
 */
@Test
public void testAccept() {
    float[] pointToScore = { 0.0f, 0.0f };
    int sampleSize = 50;
    SimpleInterpolationVisitor visitor = new SimpleInterpolationVisitor(pointToScore, sampleSize, 1, false);
    // Leaf at {1.0, -2.0} with a degenerate (single-point) bounding box and mass 3.
    INodeView leafNode = mock(NodeView.class);
    float[] point = new float[] { 1.0f, -2.0f };
    when(leafNode.getLeafPoint()).thenReturn(point);
    when(leafNode.getBoundingBox()).thenReturn(new BoundingBox(point, point));
    int leafMass = 3;
    when(leafNode.getMass()).thenReturn(leafMass);
    int depth = 4;
    visitor.acceptLeaf(leafNode, depth);
    InterpolationMeasure result = visitor.getResult();
    // Numerically |1.0 - 0.0| + |-2.0 - 0.0|: the per-dimension gaps between leaf point and query.
    double expectedSumOfNewRange = 1.0 + 2.0;
    // Interleaved layout: {dim0 -> high, dim0 -> low, dim1 -> high, dim1 -> low}.
    double[] expectedDifferenceInRangeVector = { 0.0, 1.0, 2.0, 0.0 };
    double[] expectedProbVector = Arrays.stream(expectedDifferenceInRangeVector).map(x -> x / expectedSumOfNewRange).toArray();
    // Independent copy of expectedProbVector; it is scaled in place below.
    double[] expectedNumPts = Arrays.stream(expectedProbVector).toArray();
    double[] expectedDistances = new double[2 * pointToScore.length];
    for (int i = 0; i < 2 * pointToScore.length; i++) {
        expectedDistances[i] = expectedProbVector[i] * expectedDifferenceInRangeVector[i];
    }
    // NOTE(review): 4 appears to be 1 + leafMass, matching the (1 + mass) factor used for the
    // parent step below — TODO confirm and consider writing it as (1 + leafMass).
    for (int i = 0; i < 2 * pointToScore.length; i++) {
        expectedNumPts[i] = expectedNumPts[i] * 4;
    }
    for (int i = 0; i < pointToScore.length; i++) {
        assertEquals(expectedProbVector[2 * i], result.probMass.high[i]);
        assertEquals(expectedProbVector[2 * i + 1], result.probMass.low[i]);
        assertEquals(expectedNumPts[2 * i], result.measure.high[i]);
        assertEquals(expectedNumPts[2 * i + 1], result.measure.low[i]);
        assertEquals(expectedDistances[2 * i], result.distances.high[i]);
        assertEquals(expectedDistances[2 * i + 1], result.distances.low[i]);
    }
    // parent does not contain pointToScore
    depth--;
    // NOTE(review): `sibling` is created and its mass stubbed, but it is never passed to the
    // visitor nor returned from any other stub in this test; it appears to be unused.
    INodeView sibling = mock(NodeView.class);
    int siblingMass = 2;
    when(sibling.getMass()).thenReturn(siblingMass);
    // Parent box spans {1.0, -2.0} .. {2.0, -0.5}, which does not include (0, 0); mass 3 + 2 = 5.
    INodeView parent = mock(NodeView.class);
    int parentMass = leafMass + siblingMass;
    when(parent.getMass()).thenReturn(parentMass);
    when(parent.getBoundingBox()).thenReturn(new BoundingBox(point, new float[] { 2.0f, -0.5f }));
    visitor.accept(parent, depth);
    result = visitor.getResult();
    // Presumably the widths of the parent box stretched to also cover (0, 0), i.e. [0, 2] x [-2, 0]
    // — TODO confirm.
    double expectedSumOfNewRange2 = 2.0 + 2.0;
    // The numerator equals the sum of expectedDifferenceInRangeVector2 below.
    double expectedProbOfCut2 = (1.0 + 0.5) / expectedSumOfNewRange2;
    double[] expectedDifferenceInRangeVector2 = { 0.0, 1.0, 0.5, 0.0 };
    double[] expectedDirectionalDistanceVector2 = { 0.0, 2.0, 2.0, 0.0 };
    // Each expected quantity becomes: this level's contribution weighted by `prob`, plus the
    // previous expected value scaled by (1 - expectedProbOfCut2).
    // parent.getMass() returns the stubbed parentMass (5).
    for (int i = 0; i < 2 * pointToScore.length; i++) {
        double prob = expectedDifferenceInRangeVector2[i] / expectedSumOfNewRange2;
        expectedProbVector[i] = prob + (1 - expectedProbOfCut2) * expectedProbVector[i];
        expectedNumPts[i] = prob * (1 + parent.getMass()) + (1 - expectedProbOfCut2) * expectedNumPts[i];
        expectedDistances[i] = prob * expectedDirectionalDistanceVector2[i] + (1 - expectedProbOfCut2) * expectedDistances[i];
    }
    for (int i = 0; i < pointToScore.length; i++) {
        assertEquals(expectedProbVector[2 * i], result.probMass.high[i]);
        assertEquals(expectedProbVector[2 * i + 1], result.probMass.low[i]);
        assertEquals(expectedNumPts[2 * i], result.measure.high[i]);
        assertEquals(expectedNumPts[2 * i + 1], result.measure.low[i]);
        assertEquals(expectedDistances[2 * i], result.distances.high[i]);
        assertEquals(expectedDistances[2 * i + 1], result.distances.low[i]);
    }
    // grandparent contains pointToScore
    // NOTE(review): despite the comment above, no grandparent node is visited. The test only
    // checks that pointInsideBox is still false after the parent step, and the final depth--
    // has no effect. This looks like an unfinished or truncated step — TODO confirm.
    assertFalse(visitor.pointInsideBox);
    depth--;
}
Also used: Assertions.assertArrayEquals (org.junit.jupiter.api.Assertions.assertArrayEquals), ArgumentMatchers.any (org.mockito.ArgumentMatchers.any), Test (org.junit.jupiter.api.Test), BoundingBox (com.amazon.randomcutforest.tree.BoundingBox), Arrays (java.util.Arrays), Assertions.assertFalse (org.junit.jupiter.api.Assertions.assertFalse), INodeView (com.amazon.randomcutforest.tree.INodeView), InterpolationMeasure (com.amazon.randomcutforest.returntypes.InterpolationMeasure), Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals), Mockito.when (org.mockito.Mockito.when), NodeView (com.amazon.randomcutforest.tree.NodeView), Mockito.mock (org.mockito.Mockito.mock)

Example 2 with InterpolationMeasure

Use of com.amazon.randomcutforest.returntypes.InterpolationMeasure in the project random-cut-forest-by-aws by aws.

In class RandomCutForestTest, method testGetSimpleDensity:

/**
 * Exercises {@code getSimpleDensity} in two phases.
 *
 * <p>While the samplers are not full, the forest's density must equal that of an empty
 * {@code DensityOutput(dimensions, sampleSize)}.
 *
 * <p>Once {@code samplersFull()} is stubbed to {@code true}, and each tree's {@code traverse}
 * is stubbed to return a randomly filled {@link InterpolationMeasure}, the forest's density must
 * match the density obtained by combining those per-tree measures with
 * {@link InterpolationMeasure#collector}.
 */
@Test
public void testGetSimpleDensity() {
    float[] queryPoint = { 12.3f, -45.6f };

    // Phase 1: samplers not yet full -> empty density.
    DensityOutput emptyDensity = new DensityOutput(dimensions, sampleSize);
    assertFalse(forest.samplersFull());
    DensityOutput actual = forest.getSimpleDensity(queryPoint);
    assertEquals(emptyDensity.getDensity(), actual.getDensity(), EPSILON);

    // Phase 2: samplers full, every tree returns a canned random measure.
    doReturn(true).when(forest).samplersFull();
    List<InterpolationMeasure> perTreeMeasures = new ArrayList<>();
    for (int treeIndex = 0; treeIndex < numberOfTrees; treeIndex++) {
        InterpolationMeasure canned = new InterpolationMeasure(dimensions, sampleSize);
        for (int dim = 0; dim < dimensions; dim++) {
            canned.measure.high[dim] = Math.random();
            canned.measure.low[dim] = Math.random();
            canned.distances.high[dim] = Math.random();
            canned.distances.low[dim] = Math.random();
            canned.probMass.high[dim] = Math.random();
            canned.probMass.low[dim] = Math.random();
        }
        ITree<Integer, float[]> stubbedTree = ((SamplerPlusTree<Integer, float[]>) components.get(treeIndex)).getTree();
        when(stubbedTree.traverse(aryEq(queryPoint), any(VisitorFactory.class))).thenReturn(canned);
        perTreeMeasures.add(canned);
    }

    Collector<InterpolationMeasure, ?, InterpolationMeasure> combiner = InterpolationMeasure.collector(dimensions, sampleSize, numberOfTrees);
    InterpolationMeasure combined = perTreeMeasures.stream().collect(combiner);
    DensityOutput expected = new DensityOutput(combined);
    actual = forest.getSimpleDensity(queryPoint);
    assertEquals(expected.getDensity(), actual.getDensity(), EPSILON);
}
Also used: DensityOutput (com.amazon.randomcutforest.returntypes.DensityOutput), ArrayList (java.util.ArrayList), InterpolationMeasure (com.amazon.randomcutforest.returntypes.InterpolationMeasure), SamplerPlusTree (com.amazon.randomcutforest.executor.SamplerPlusTree), Test (org.junit.jupiter.api.Test)

Example 3 with InterpolationMeasure

Use of com.amazon.randomcutforest.returntypes.InterpolationMeasure in the project random-cut-forest-by-aws by aws.

In class RandomCutForest, method getSimpleDensity:

/**
 * Computes a simple density estimate at {@code point} using every tree in the forest.
 *
 * <p>If {@link #samplersFull()} returns {@code false}, no tree is traversed and an empty
 * {@code DensityOutput(dimensions, sampleSize)} is returned.
 *
 * <p>Otherwise the point is passed through {@code transformToShingledPoint}. Each tree is then
 * traversed with a {@link SimpleInterpolationVisitor} built on the point projected into that
 * tree ({@code tree.projectToTree}), and each per-tree result is mapped back with
 * {@code tree.liftFromTree}. The per-tree results are combined with
 * {@link InterpolationMeasure#collector}.
 *
 * @param point the query point (presumably in the forest's input, unshingled, dimension — confirm)
 * @return the density output for {@code point}, or an empty one while samplers are not full
 */
public DensityOutput getSimpleDensity(float[] point) {
    if (samplersFull()) {
        IVisitorFactory<InterpolationMeasure> factory = new VisitorFactory<>(
                (tree, projected) -> new SimpleInterpolationVisitor(tree.projectToTree(projected), sampleSize, 1.0, centerOfMassEnabled),
                (tree, measure) -> measure.lift(tree::liftFromTree));
        Collector<InterpolationMeasure, ?, InterpolationMeasure> combiner = InterpolationMeasure.collector(dimensions, sampleSize, numberOfTrees);
        float[] shingledPoint = transformToShingledPoint(point);
        return new DensityOutput(traverseForest(shingledPoint, factory, combiner));
    }
    return new DensityOutput(dimensions, sampleSize);
}
Also used: DensityOutput (com.amazon.randomcutforest.returntypes.DensityOutput), SimpleInterpolationVisitor (com.amazon.randomcutforest.interpolation.SimpleInterpolationVisitor), InterpolationMeasure (com.amazon.randomcutforest.returntypes.InterpolationMeasure)

Example 4 with InterpolationMeasure

Use of com.amazon.randomcutforest.returntypes.InterpolationMeasure in the project random-cut-forest-by-aws by aws.

In class SimpleInterpolationVisitorTest, method testAcceptEqualsLeafPoint:

/**
 * Verifies {@link SimpleInterpolationVisitor} when the query point coincides with the leaf point,
 * and then after visiting a parent whose box spans the leaf and a sibling point at {1.0, -2.0}.
 *
 * <p>For the parent step, expected values are kept in vectors of length
 * {@code 2 * pointToScore.length}. Index {@code 2 * i} is compared against the {@code high}
 * arrays and index {@code 2 * i + 1} against the {@code low} arrays. Comparisons use exact
 * double equality.
 */
@Test
public void testAcceptEqualsLeafPoint() {
    float[] pointToScore = { 0.0f, 0.0f };
    int sampleSize = 50;
    SimpleInterpolationVisitor visitor = new SimpleInterpolationVisitor(pointToScore, sampleSize, 1, false);
    // Leaf point is an exact copy of the query point; degenerate box; mass 1.
    float[] point = Arrays.copyOf(pointToScore, pointToScore.length);
    INodeView node = mock(NodeView.class);
    when(node.getLeafPoint()).thenReturn(point);
    when(node.getBoundingBox()).thenReturn(new BoundingBox(point, point));
    when(node.getMass()).thenReturn(1);
    int depth = 2;
    visitor.acceptLeaf(node, depth);
    InterpolationMeasure result = visitor.getResult();
    // After the leaf, each high/low entry of measure is 0.5 * (1 + mass) / dimensions, and of
    // probMass is 0.5 / dimensions; distances stay zero. node.getMass() returns the stubbed 1.
    double[] expected = new double[point.length];
    Arrays.fill(expected, 0.5 * (1 + node.getMass()) / point.length);
    assertArrayEquals(expected, result.measure.high);
    assertArrayEquals(expected, result.measure.low);
    Arrays.fill(expected, 0.5 / point.length);
    assertArrayEquals(expected, result.probMass.high);
    assertArrayEquals(expected, result.probMass.low);
    Arrays.fill(expected, 0.0);
    assertArrayEquals(expected, result.distances.high);
    assertArrayEquals(expected, result.distances.low);
    depth--;
    float[] siblingPoint = { 1.0f, -2.0f };
    // NOTE(review): `sibling` is created and its mass stubbed, but it is never passed to the
    // visitor; the parent's getSiblingBoundingBox stub returns a BoundingBox directly. It appears
    // to be unused.
    INodeView sibling = mock(NodeView.class);
    int siblingMass = 2;
    when(sibling.getMass()).thenReturn(siblingMass);
    // Parent: mass 1 + 2 = 3, box spanning the leaf point and siblingPoint.
    INodeView parent = mock(NodeView.class);
    when(parent.getMass()).thenReturn(1 + siblingMass);
    BoundingBox boundingBox = new BoundingBox(point, siblingPoint);
    when(parent.getBoundingBox()).thenReturn(boundingBox);
    when(parent.getSiblingBoundingBox(any())).thenReturn(new BoundingBox(siblingPoint));
    visitor.accept(parent, depth);
    result = visitor.getResult();
    // compute using shadow box (sibling leaf node at {1.0, -2.0} and parent
    // bounding box)
    double[] directionalDistance = { 0.0, 1.0, 2.0, 0.0 };
    double[] differenceInRange = { 0.0, 1.0, 2.0, 0.0 };
    double sumOfNewRange = 1.0 + 2.0;
    double[] probVector = Arrays.stream(differenceInRange).map(x -> x / sumOfNewRange).toArray();
    // `expected` is reallocated at the interleaved length 2 * dimensions for the parent step.
    expected = new double[2 * pointToScore.length];
    // Mass factor: 1 + node.getMass() (1) + parent.getMass() (3) = 5.
    for (int i = 0; i < expected.length; i++) {
        expected[i] = probVector[i] * (1 + node.getMass() + parent.getMass());
    }
    for (int i = 0; i < pointToScore.length; i++) {
        assertEquals(expected[2 * i], result.measure.high[i]);
        assertEquals(expected[2 * i + 1], result.measure.low[i]);
    }
    for (int i = 0; i < expected.length; i++) {
        expected[i] = probVector[i];
    }
    for (int i = 0; i < pointToScore.length; i++) {
        assertEquals(expected[2 * i], result.probMass.high[i]);
        assertEquals(expected[2 * i + 1], result.probMass.low[i]);
    }
    for (int i = 0; i < expected.length; i++) {
        expected[i] = probVector[i] * directionalDistance[i];
    }
    for (int i = 0; i < pointToScore.length; i++) {
        assertEquals(expected[2 * i], result.distances.high[i]);
        assertEquals(expected[2 * i + 1], result.distances.low[i]);
    }
}
Also used: Assertions.assertArrayEquals (org.junit.jupiter.api.Assertions.assertArrayEquals), ArgumentMatchers.any (org.mockito.ArgumentMatchers.any), Test (org.junit.jupiter.api.Test), BoundingBox (com.amazon.randomcutforest.tree.BoundingBox), Arrays (java.util.Arrays), Assertions.assertFalse (org.junit.jupiter.api.Assertions.assertFalse), INodeView (com.amazon.randomcutforest.tree.INodeView), InterpolationMeasure (com.amazon.randomcutforest.returntypes.InterpolationMeasure), Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals), Mockito.when (org.mockito.Mockito.when), NodeView (com.amazon.randomcutforest.tree.NodeView), Mockito.mock (org.mockito.Mockito.mock)

Example 5 with InterpolationMeasure

Use of com.amazon.randomcutforest.returntypes.InterpolationMeasure in the project random-cut-forest-by-aws by aws.

In class SimpleInterpolationVisitorTest, method testNew:

/**
 * Checks the initial state of a freshly constructed {@link SimpleInterpolationVisitor}:
 * <ul>
 *   <li>{@code pointInsideBox} is false;</li>
 *   <li>there is one {@code coordInsideBox} flag per coordinate, and every flag is false;</li>
 *   <li>all six result arrays (measure, distances and probMass, each high and low) are zero.</li>
 * </ul>
 */
@Test
public void testNew() {
    float[] queryPoint = { 1.0f, 2.0f };
    int sampleSize = 9;
    SimpleInterpolationVisitor freshVisitor = new SimpleInterpolationVisitor(queryPoint, sampleSize, 1, false);

    assertFalse(freshVisitor.pointInsideBox);
    assertEquals(2, freshVisitor.coordInsideBox.length);
    // The length check above guarantees exactly one flag per query coordinate.
    for (boolean coordinateInside : freshVisitor.coordInsideBox) {
        assertFalse(coordinateInside);
    }

    InterpolationMeasure initial = freshVisitor.getResult();
    double[] zeros = new double[queryPoint.length];
    // Checked in the same order as the original: all "high" arrays first, then all "low" arrays.
    double[][] arraysInAssertionOrder = {
            initial.measure.high, initial.distances.high, initial.probMass.high,
            initial.measure.low, initial.distances.low, initial.probMass.low };
    for (double[] values : arraysInAssertionOrder) {
        assertArrayEquals(zeros, values);
    }
}
Also used: InterpolationMeasure (com.amazon.randomcutforest.returntypes.InterpolationMeasure), Test (org.junit.jupiter.api.Test)

Aggregations

InterpolationMeasure (com.amazon.randomcutforest.returntypes.InterpolationMeasure)7 Test (org.junit.jupiter.api.Test)6 BoundingBox (com.amazon.randomcutforest.tree.BoundingBox)4 INodeView (com.amazon.randomcutforest.tree.INodeView)4 NodeView (com.amazon.randomcutforest.tree.NodeView)3 Arrays (java.util.Arrays)3 Assertions.assertArrayEquals (org.junit.jupiter.api.Assertions.assertArrayEquals)3 Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals)3 Assertions.assertFalse (org.junit.jupiter.api.Assertions.assertFalse)3 ArgumentMatchers.any (org.mockito.ArgumentMatchers.any)3 Mockito.mock (org.mockito.Mockito.mock)3 Mockito.when (org.mockito.Mockito.when)3 DensityOutput (com.amazon.randomcutforest.returntypes.DensityOutput)2 SamplerPlusTree (com.amazon.randomcutforest.executor.SamplerPlusTree)1 SimpleInterpolationVisitor (com.amazon.randomcutforest.interpolation.SimpleInterpolationVisitor)1 ArrayList (java.util.ArrayList)1