Search in sources:

Example 1 with Neighbor

Usage of com.amazon.randomcutforest.returntypes.Neighbor in the project random-cut-forest-by-aws by aws.

From the class DynamicPointSetFunctionalTest, method movingNeighbors.

/**
 * Streams a rotating "fan" of points through a forest and checks that near-neighbor
 * queries around the fixed query point {@code (0.7, 0)} follow the fan as it turns.
 *
 * <p>After every 2-degree rotation step the forest is queried twice. One query uses
 * distance threshold 1 (the nearest neighbor is taken from this list). The other uses
 * threshold 0.1 (the "close" neighbors).
 *
 * <p>For rotation ranges where the fan is away from the query point, there must be no
 * close neighbor and the nearest neighbor must be farther than 0.3. For ranges where the
 * fan is overhead, there must be a close neighbor, and its distance must match the
 * nearest-neighbor distance.
 */
@Test
public void movingNeighbors() {
    int newDimensions = 2;
    randomSeed = 123;
    RandomCutForest newForest = RandomCutForest.builder().dimensions(newDimensions).randomSeed(randomSeed)
            .timeDecay(1.0 / 800).centerOfMassEnabled(true).storeSequenceIndexesEnabled(true).build();
    double[][] data = generateFan(1000, 3);
    double[] queryPoint = new double[] { 0.7, 0 };
    for (int degree = 0; degree < 360; degree += 2) {
        // The rotation angle is constant for the whole batch, so compute it once per step.
        double angle = 2 * PI * degree / 360;
        for (int j = 0; j < data.length; j++) {
            newForest.update(rotateClockWise(data[j], angle));
        }
        List<Neighbor> ans = newForest.getNearNeighborsInSample(queryPoint, 1);
        List<Neighbor> closeNeighbors = newForest.getNearNeighborsInSample(queryPoint, 0.1);
        Neighbor best = null;
        if (ans != null) {
            // Fail with a clear message instead of an IndexOutOfBoundsException from get(0).
            assertTrue(!ans.isEmpty(), "expected at least one neighbor within distance 1 at degree " + degree);
            best = ans.get(0);
            for (int j = 1; j < ans.size(); j++) {
                // Use a JUnit assertion: the `assert` keyword is skipped unless the JVM runs with -ea.
                assertTrue(ans.get(j).distance >= best.distance,
                        "neighbor " + j + " is closer than the first neighbor at degree " + degree);
            }
        }
        // fan is away at 30, 150 and 270
        if (((degree > 15) && (degree < 45)) || ((degree >= 135) && (degree <= 165))
                || ((degree >= 255) && (degree <= 285))) {
            // no close neighbor
            assertTrue(closeNeighbors.isEmpty());
            assertTrue(best != null, "no nearest neighbor available at degree " + degree);
            assertTrue(best.distance > 0.3);
        }
        // fan is overhead at 90, 210 and 330
        if (((degree > 75) && (degree < 105)) || ((degree >= 195) && (degree <= 225))
                || ((degree >= 315) && (degree <= 345))) {
            assertTrue(closeNeighbors.size() > 0);
            assertTrue(best != null, "no nearest neighbor available at degree " + degree);
            assertEquals(closeNeighbors.get(0).distance, best.distance, 1E-10);
        }
    }
}
Also used: Neighbor (com.amazon.randomcutforest.returntypes.Neighbor), Test (org.junit.jupiter.api.Test)

Example 2 with Neighbor

Usage of com.amazon.randomcutforest.returntypes.Neighbor in the project random-cut-forest-by-aws by aws.

From the class RandomCutForestTest, method testGetNearNeighborInSample.

@Test
public void testGetNearNeighborInSample() {
    // Two distinct points, (1, 2) at distance 5 and (2, 3) at distance 4. Each is reported
    // by two trees carrying sequence indexes {1, 3} and {2, 4} respectively.
    Neighbor neighbor1 = new Neighbor(new double[] { 1, 2 }, 5, new ArrayList<>(Arrays.asList(1L, 3L)));
    Neighbor neighbor2 = new Neighbor(new double[] { 1, 2 }, 5, new ArrayList<>(Arrays.asList(2L, 4L)));
    Neighbor neighbor4 = new Neighbor(new double[] { 2, 3 }, 4, new ArrayList<>(Arrays.asList(1L, 3L)));
    Neighbor neighbor5 = new Neighbor(new double[] { 2, 3 }, 4, new ArrayList<>(Arrays.asList(2L, 4L)));

    // Traversal result for each component, by index. Components 0, 1, 3 and 4 report a
    // neighbor; every other component (including 2) reports nothing.
    List<Optional<Neighbor>> traversalResults = new ArrayList<>();
    for (int i = 0; i < components.size(); i++) {
        traversalResults.add(Optional.empty());
    }
    traversalResults.set(0, Optional.of(neighbor1));
    traversalResults.set(1, Optional.of(neighbor2));
    traversalResults.set(3, Optional.of(neighbor4));
    traversalResults.set(4, Optional.of(neighbor5));

    for (int i = 0; i < traversalResults.size(); i++) {
        SamplerPlusTree<?, ?> component = (SamplerPlusTree<?, ?>) components.get(i);
        when(component.getTree().traverse(any(float[].class), any(IVisitorFactory.class)))
                .thenReturn(traversalResults.get(i));
    }

    Whitebox.setInternalState(forest, "storeSequenceIndexesEnabled", true);
    doReturn(true).when(forest).isOutputReady();

    List<Neighbor> neighbors = forest.getNearNeighborsInSample(new double[] { 0, 0 }, 5);

    // The assertions expect duplicates to be merged (indexes combined) and results
    // ordered by increasing distance.
    List<Long> expectedIndexes = Arrays.asList(1L, 2L, 3L, 4L);
    assertEquals(2, neighbors.size());

    Neighbor nearest = neighbors.get(0);
    assertTrue(nearest.point[0] == 2 && nearest.point[1] == 3);
    assertEquals(4, nearest.distance);
    assertEquals(4, nearest.sequenceIndexes.size());
    assertThat(nearest.sequenceIndexes, is(expectedIndexes));

    Neighbor runnerUp = neighbors.get(1);
    assertTrue(runnerUp.point[0] == 1 && runnerUp.point[1] == 2);
    assertEquals(5, runnerUp.distance);
    assertEquals(4, runnerUp.sequenceIndexes.size());
    assertThat(runnerUp.sequenceIndexes, is(expectedIndexes));
}
Also used: ArrayList (java.util.ArrayList), Neighbor (com.amazon.randomcutforest.returntypes.Neighbor), Test (org.junit.jupiter.api.Test)

Example 3 with Neighbor

Usage of com.amazon.randomcutforest.returntypes.Neighbor in the project random-cut-forest-by-aws by aws.

From the class NearNeighborVisitorTest, method acceptLeafNotNear.

@Test
public void acceptLeafNotNear() {
    // Leaf data: a point the visitor is expected to reject, plus two sequence indexes.
    float[] leafPoint = new float[] { 108.8f, 209.9f, -305.5f };
    HashMap<Long, Integer> leafSequenceIndexes = new HashMap<>();
    leafSequenceIndexes.put(1234L, 1);
    leafSequenceIndexes.put(5678L, 1);

    // The same array backs both the raw and the lifted leaf point.
    INodeView leafNode = mock(NodeView.class);
    when(leafNode.getLeafPoint()).thenReturn(leafPoint);
    when(leafNode.getLiftedLeafPoint()).thenReturn(leafPoint);
    when(leafNode.getSequenceIndexes()).thenReturn(leafSequenceIndexes);

    int leafDepth = 12;
    visitor.acceptLeaf(leafNode, leafDepth);

    Optional<Neighbor> result = visitor.getResult();
    assertFalse(result.isPresent());
}
Also used: HashMap (java.util.HashMap), Neighbor (com.amazon.randomcutforest.returntypes.Neighbor), INodeView (com.amazon.randomcutforest.tree.INodeView), Test (org.junit.jupiter.api.Test)

Example 4 with Neighbor

Usage of com.amazon.randomcutforest.returntypes.Neighbor in the project random-cut-forest-by-aws by aws.

From the class NearNeighborVisitor, method acceptLeaf.

/**
 * Check whether the Euclidean distance between the leaf's lifted point
 * ({@link INodeView#getLiftedLeafPoint()}) and the query point is strictly less than
 * the distance threshold.
 *
 * <p>If it is, this visitor stores a {@link Neighbor}. The neighbor holds a
 * {@code double[]} copy of the lifted leaf point, the computed distance, and a new list
 * containing the keys of the leaf's sequence-index map. Otherwise {@code neighbor} is
 * not assigned by this call, so the visitor's result is whatever it was before.
 * NOTE(review): this is presumably an empty Optional for a fresh visitor; confirm
 * against {@code getResult()}.
 *
 * @param leafNode    the leaf node being visited
 * @param depthOfNode the depth of the leaf node (not used by this implementation)
 */
@Override
public void acceptLeaf(INodeView leafNode, int depthOfNode) {
    float[] leafPoint = leafNode.getLiftedLeafPoint();
    double distanceSquared = 0.0;
    // Iterates over the leaf point's dimensions; queryPoint is assumed to have at least as many.
    for (int i = 0; i < leafPoint.length; i++) {
        double diff = queryPoint[i] - leafPoint[i];
        distanceSquared += diff * diff;
    }
    // Compute the square root once; it is used both for the threshold test and the result.
    double distance = Math.sqrt(distanceSquared);
    if (distance < distanceThreshold) {
        // Copy the keys so the Neighbor does not share the leaf's map.
        List<Long> sequenceIndexes = new ArrayList<>(leafNode.getSequenceIndexes().keySet());
        neighbor = new Neighbor(toDoubleArray(leafPoint), distance, sequenceIndexes);
    }
}
Also used: ArrayList (java.util.ArrayList), Neighbor (com.amazon.randomcutforest.returntypes.Neighbor)

Example 5 with Neighbor

Usage of com.amazon.randomcutforest.returntypes.Neighbor in the project random-cut-forest-by-aws by aws.

From the class NearNeighborVisitorTest, method acceptLeafNear.

/**
 * A leaf close to the query point must yield a neighbor. That neighbor must hold a copy
 * of the leaf point, the expected Euclidean distance, and the leaf's sequence indexes.
 */
@Test
public void acceptLeafNear() {
    float[] leafPoint = new float[] { 8.8f, 9.9f, -5.5f };
    INodeView leafNode = mock(NodeView.class);
    // Return copies so the test can verify the visitor does not hand back the original array.
    when(leafNode.getLeafPoint()).thenReturn(Arrays.copyOf(leafPoint, leafPoint.length));
    when(leafNode.getLiftedLeafPoint()).thenReturn(Arrays.copyOf(leafPoint, leafPoint.length));
    HashMap<Long, Integer> sequenceIndexes = new HashMap<>();
    sequenceIndexes.put(1234L, 1);
    sequenceIndexes.put(5678L, 1);
    when(leafNode.getSequenceIndexes()).thenReturn(sequenceIndexes);
    int depth = 12;
    visitor.acceptLeaf(leafNode, depth);
    Optional<Neighbor> optional = visitor.getResult();
    assertTrue(optional.isPresent());
    Neighbor neighbor = optional.get();
    assertNotSame(leafPoint, neighbor.point);
    assertArrayEquals(toDoubleArray(leafPoint), neighbor.point);
    assertEquals(Math.sqrt(3 * 1.1 * 1.1), neighbor.distance, EPSILON);
    // getSequenceIndexes() is stubbed to return a HashMap, while neighbor.sequenceIndexes is a
    // List, so this assertNotSame can only pass. Verify the copied contents explicitly as well.
    assertNotSame(leafNode.getSequenceIndexes(), neighbor.sequenceIndexes);
    assertEquals(sequenceIndexes.size(), neighbor.sequenceIndexes.size());
    assertTrue(neighbor.sequenceIndexes.containsAll(sequenceIndexes.keySet()));
}
Also used: HashMap (java.util.HashMap), Neighbor (com.amazon.randomcutforest.returntypes.Neighbor), INodeView (com.amazon.randomcutforest.tree.INodeView), Test (org.junit.jupiter.api.Test)

Aggregations

Neighbor (com.amazon.randomcutforest.returntypes.Neighbor): 6; Test (org.junit.jupiter.api.Test): 5; INodeView (com.amazon.randomcutforest.tree.INodeView): 3; ArrayList (java.util.ArrayList): 2; HashMap (java.util.HashMap): 2