Search in sources:

Example 16 using INodeView

Use of `com.amazon.randomcutforest.tree.INodeView` in the project random-cut-forest-by-aws by aws.

Class `SimpleInterpolationVisitorTest`, method `testAcceptLeafNotEquals`.

/**
 * Visits a single leaf whose point differs from the query point and checks the
 * probability mass, measure and distance entries of the resulting
 * {@code InterpolationMeasure}, dimension by dimension, using exact equality.
 */
@Test
public void testAcceptLeafNotEquals() {
    float[] queryPoint = { 1.0f, 9.0f, 4.0f };
    float[] leafPoint = { 4.0f, 5.0f, 6.0f };

    // The mocked leaf reports a degenerate bounding box located at its own point.
    INodeView leafNode = mock(NodeView.class);
    when(leafNode.getLeafPoint()).thenReturn(leafPoint);
    when(leafNode.getBoundingBox()).thenReturn(new BoundingBox(leafPoint, leafPoint));
    when(leafNode.getMass()).thenReturn(4);

    int depthOfLeaf = 100;
    int samples = 99;
    SimpleInterpolationVisitor visitor = new SimpleInterpolationVisitor(queryPoint, samples, 1, false);
    visitor.acceptLeaf(leafNode, depthOfLeaf);
    InterpolationMeasure result = visitor.getResult();

    // Expected per-dimension range differences, interleaved as (high, low) pairs:
    // slot 2*d pairs with result.*.high[d] and slot 2*d+1 with result.*.low[d].
    // dim 0: 1.0 is 3.0 below 4.0 (low); dim 1: 9.0 is 4.0 above 5.0 (high);
    // dim 2: 4.0 is 2.0 below 6.0 (low).
    double[] rangeGrowth = { 0.0, 3.0, 4.0, 0.0, 0.0, 2.0 };
    double totalGrowth = 3.0 + 4.0 + 2.0;

    for (int dim = 0; dim < queryPoint.length; dim++) {
        int highSlot = 2 * dim;
        int lowSlot = highSlot + 1;
        double probHigh = rangeGrowth[highSlot] / totalGrowth;
        double probLow = rangeGrowth[lowSlot] / totalGrowth;

        assertEquals(probHigh, result.probMass.high[dim]);
        assertEquals(probLow, result.probMass.low[dim]);
        // NOTE(review): the factor 5 is taken verbatim from the expected values;
        // presumably related to the leaf mass (4) — confirm against the visitor.
        assertEquals(probHigh * 5, result.measure.high[dim]);
        assertEquals(probLow * 5, result.measure.low[dim]);
        assertEquals(probHigh * rangeGrowth[highSlot], result.distances.high[dim]);
        assertEquals(probLow * rangeGrowth[lowSlot], result.distances.low[dim]);
    }
}
Also used: Assertions.assertArrayEquals (org.junit.jupiter.api.Assertions.assertArrayEquals), ArgumentMatchers.any (org.mockito.ArgumentMatchers.any), Test (org.junit.jupiter.api.Test), BoundingBox (com.amazon.randomcutforest.tree.BoundingBox), Arrays (java.util.Arrays), Assertions.assertFalse (org.junit.jupiter.api.Assertions.assertFalse), INodeView (com.amazon.randomcutforest.tree.INodeView), InterpolationMeasure (com.amazon.randomcutforest.returntypes.InterpolationMeasure), Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals), Mockito.when (org.mockito.Mockito.when), NodeView (com.amazon.randomcutforest.tree.NodeView), Mockito.mock (org.mockito.Mockito.mock)

Example 17 using INodeView

Use of `com.amazon.randomcutforest.tree.INodeView` in the project random-cut-forest-by-aws by aws.

Class `AnomalyAttributionVisitorTest`, method `testAcceptLeafNotEquals`.

/**
 * Visits a single leaf whose point differs from the query point and checks the
 * attribution {@code DiVector} produced by {@code AnomalyAttributionVisitor} for
 * three values of its third constructor argument (0, 3 and 4).
 */
@Test
public void testAcceptLeafNotEquals() {
    float[] queryPoint = new float[] { 1.1f, -2.2f, 3.3f };
    float[] leafPoint = new float[] { -4.0f, 5.0f, 6.0f };
    int depthOfLeaf = 100;
    int massOfLeaf = 4;
    int massOfTree = 21;

    // The mocked leaf reports a degenerate bounding box located at its own point.
    INodeView leafNode = mock(NodeView.class);
    when(leafNode.getLeafPoint()).thenReturn(leafPoint);
    when(leafNode.getBoundingBox()).thenReturn(new BoundingBox(leafPoint, leafPoint));
    when(leafNode.getMass()).thenReturn(massOfLeaf);

    double scoreTotal = defaultScoreUnseenFunction(depthOfLeaf, massOfLeaf);

    // Per-dimension gap between query and leaf point. Dim 0 lies above the leaf
    // (high side); dims 1 and 2 lie below it (low side).
    double gap0 = 1.1 - (-4.0);
    double gap1 = 5.0 - (-2.2);
    double gap2 = 6.0 - 3.3;
    double gapTotal = gap0 + gap1 + gap2;

    // Arguments 0 and 3 are expected to yield the same proportional attribution.
    for (int option : new int[] { 0, 3 }) {
        AnomalyAttributionVisitor visitor = new AnomalyAttributionVisitor(queryPoint, massOfTree, option);
        visitor.acceptLeaf(leafNode, depthOfLeaf);
        DiVector attribution = visitor.getResult();

        assertEquals(defaultScalarNormalizerFunction(scoreTotal * gap0 / gapTotal, massOfTree), attribution.high[0], EPSILON);
        assertEquals(0.0, attribution.low[0]);
        assertEquals(0.0, attribution.high[1]);
        assertEquals(defaultScalarNormalizerFunction(scoreTotal * gap1 / gapTotal, massOfTree), attribution.low[1], EPSILON);
        assertEquals(0.0, attribution.high[2]);
        assertEquals(defaultScalarNormalizerFunction(scoreTotal * gap2 / gapTotal, massOfTree), attribution.low[2], EPSILON);
    }

    // With argument 4 the score is expected to be split evenly across all 2 * d slots.
    AnomalyAttributionVisitor evenVisitor = new AnomalyAttributionVisitor(queryPoint, massOfTree, 4);
    evenVisitor.acceptLeaf(leafNode, depthOfLeaf);
    DiVector evenAttribution = evenVisitor.getResult();
    double perSlot = defaultScalarNormalizerFunction(scoreTotal / (2 * queryPoint.length), massOfTree);
    for (int dim = 0; dim < queryPoint.length; dim++) {
        assertEquals(perSlot, evenAttribution.low[dim], EPSILON);
        assertEquals(perSlot, evenAttribution.high[dim], EPSILON);
    }
}
Also used: DiVector (com.amazon.randomcutforest.returntypes.DiVector), BoundingBox (com.amazon.randomcutforest.tree.BoundingBox), INodeView (com.amazon.randomcutforest.tree.INodeView), Test (org.junit.jupiter.api.Test)

Aggregations

INodeView (com.amazon.randomcutforest.tree.INodeView): 17, Test (org.junit.jupiter.api.Test): 17, BoundingBox (com.amazon.randomcutforest.tree.BoundingBox): 14, InterpolationMeasure (com.amazon.randomcutforest.returntypes.InterpolationMeasure): 4, NodeView (com.amazon.randomcutforest.tree.NodeView): 4, DiVector (com.amazon.randomcutforest.returntypes.DiVector): 3, Neighbor (com.amazon.randomcutforest.returntypes.Neighbor): 3, Arrays (java.util.Arrays): 3, Assertions.assertArrayEquals (org.junit.jupiter.api.Assertions.assertArrayEquals): 3, Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals): 3, Assertions.assertFalse (org.junit.jupiter.api.Assertions.assertFalse): 3, ArgumentMatchers.any (org.mockito.ArgumentMatchers.any): 3, Mockito.mock (org.mockito.Mockito.mock): 3, Mockito.when (org.mockito.Mockito.when): 3, IBoundingBoxView (com.amazon.randomcutforest.tree.IBoundingBoxView): 2, HashMap (java.util.HashMap): 2