Use of com.amazon.randomcutforest.returntypes.Neighbor in the project random-cut-forest-by-aws by AWS.
Class DynamicPointSetFunctionalTest, method movingNeighbors.
/**
 * Streams a fan-shaped data set into a forest while rotating it, and checks
 * near-neighbor queries around the fixed query point (0.7, 0).
 *
 * <p>For each angle 0, 2, ..., 358 degrees, every point of the fan is rotated
 * clockwise by that angle and fed to the forest. The forest is then queried with
 * radius 1 and radius 0.1.
 * <ul>
 * <li>The first radius-1 result must be the nearest.</li>
 * <li>While the fan is away from the query point, nothing may lie within 0.1,
 * and the nearest sample within radius 1 (if any) must be farther than 0.3.</li>
 * <li>While the fan is overhead, something must lie within 0.1, and the closest
 * of those must match the overall nearest neighbor.</li>
 * </ul>
 */
@Test
public void movingNeighbors() {
    int newDimensions = 2;
    randomSeed = 123;
    RandomCutForest newForest = RandomCutForest.builder().dimensions(newDimensions).randomSeed(randomSeed)
            .timeDecay(1.0 / 800).centerOfMassEnabled(true).storeSequenceIndexesEnabled(true).build();
    // NOTE(review): presumably 1000 points arranged in 3 blades -- confirm against generateFan.
    double[][] data = generateFan(1000, 3);
    double[] queryPoint = new double[] { 0.7, 0 };
    for (int degree = 0; degree < 360; degree += 2) {
        // Same angle (in radians) for every point of this step; computed once.
        double angle = 2 * PI * degree / 360;
        for (int j = 0; j < data.length; j++) {
            newForest.update(rotateClockWise(data[j], angle));
        }
        List<Neighbor> ans = newForest.getNearNeighborsInSample(queryPoint, 1);
        List<Neighbor> closeNeighbors = newForest.getNearNeighborsInSample(queryPoint, 0.1);
        Neighbor best = null;
        if (ans != null && !ans.isEmpty()) {
            best = ans.get(0);
            // The first result must be the nearest one. A JUnit assertion is used because the
            // `assert` keyword is silently skipped when the JVM runs without -ea.
            for (int j = 1; j < ans.size(); j++) {
                assertTrue(ans.get(j).distance >= best.distance);
            }
        }
        // fan is away at 30, 150 and 270
        if (((degree > 15) && (degree < 45)) || ((degree >= 135) && (degree <= 165))
                || ((degree >= 255) && (degree <= 285))) {
            // no close neighbor
            assertTrue(closeNeighbors.size() == 0);
            // No sample within radius 1 also means no sample within 0.3, so an empty result is fine.
            assertTrue(best == null || best.distance > 0.3);
        }
        // fan is overhead at 90, 210 and 330
        if (((degree > 75) && (degree < 105)) || ((degree >= 195) && (degree <= 225))
                || ((degree >= 315) && (degree <= 345))) {
            assertTrue(closeNeighbors.size() > 0);
            // Anything within 0.1 is also within 1, so the wider query cannot be empty here.
            assertTrue(best != null);
            assertEquals(closeNeighbors.get(0).distance, best.distance, 1E-10);
        }
    }
}
Use of com.amazon.randomcutforest.returntypes.Neighbor in the project random-cut-forest-by-aws by AWS.
Class RandomCutForestTest, method testGetNearNeighborInSample.
/**
 * Verifies that the forest combines the per-tree near-neighbor results.
 *
 * <p>Trees 0 and 1 each report the point (1, 2) at distance 5. Trees 3 and 4 each
 * report (2, 3) at distance 4. Every other tree reports nothing. The expected result
 * is two neighbors ordered by distance. Each carries the union of its duplicates'
 * sequence indexes.
 */
@Test
public void testGetNearNeighborInSample() {
    // Index lists are kept mutable (ArrayList) in case the forest's merge step appends
    // to them -- NOTE(review): confirm against Neighbor's merge logic.
    Neighbor farFromTreeA = new Neighbor(new double[] { 1, 2 }, 5, new ArrayList<>(Arrays.asList(1L, 3L)));
    Neighbor farFromTreeB = new Neighbor(new double[] { 1, 2 }, 5, new ArrayList<>(Arrays.asList(2L, 4L)));
    Neighbor nearFromTreeA = new Neighbor(new double[] { 2, 3 }, 4, new ArrayList<>(Arrays.asList(1L, 3L)));
    Neighbor nearFromTreeB = new Neighbor(new double[] { 2, 3 }, 4, new ArrayList<>(Arrays.asList(2L, 4L)));

    // Result each tree's traversal returns, by tree position; trees past the end of this list return nothing.
    List<Optional<Neighbor>> perTreeResults = Arrays.asList(Optional.of(farFromTreeA), Optional.of(farFromTreeB),
            Optional.<Neighbor>empty(), Optional.of(nearFromTreeA), Optional.of(nearFromTreeB));
    for (int tree = 0; tree < perTreeResults.size() || tree < components.size(); tree++) {
        Optional<Neighbor> stubbed = tree < perTreeResults.size() ? perTreeResults.get(tree)
                : Optional.<Neighbor>empty();
        when(((SamplerPlusTree<?, ?>) components.get(tree)).getTree().traverse(any(float[].class),
                any(IVisitorFactory.class))).thenReturn(stubbed);
    }

    Whitebox.setInternalState(forest, "storeSequenceIndexesEnabled", true);
    doReturn(true).when(forest).isOutputReady();

    List<Neighbor> neighbors = forest.getNearNeighborsInSample(new double[] { 0, 0 }, 5);

    List<Long> expectedIndexes = Arrays.asList(1L, 2L, 3L, 4L);
    assertEquals(2, neighbors.size());

    Neighbor closest = neighbors.get(0);
    assertTrue(closest.point[0] == 2 && closest.point[1] == 3);
    assertEquals(4, closest.distance);
    assertEquals(4, closest.sequenceIndexes.size());
    assertThat(closest.sequenceIndexes, is(expectedIndexes));

    Neighbor farthest = neighbors.get(1);
    assertTrue(farthest.point[0] == 1 && farthest.point[1] == 2);
    assertEquals(5, farthest.distance);
    assertEquals(4, farthest.sequenceIndexes.size());
    assertThat(farthest.sequenceIndexes, is(expectedIndexes));
}
Use of com.amazon.randomcutforest.returntypes.Neighbor in the project random-cut-forest-by-aws by AWS.
Class NearNeighborVisitorTest, method acceptLeafNotNear.
/**
 * A leaf whose point lies far from the query point must not produce a neighbor.
 *
 * <p>NOTE(review): assumes the visitor's query point and distance threshold (set up
 * outside this method) put this leaf point beyond the threshold.
 */
@Test
public void acceptLeafNotNear() {
    float[] distantPoint = new float[] { 108.8f, 209.9f, -305.5f };
    HashMap<Long, Integer> indexCounts = new HashMap<>();
    indexCounts.put(1234L, 1);
    indexCounts.put(5678L, 1);

    INodeView leaf = mock(NodeView.class);
    when(leaf.getLeafPoint()).thenReturn(distantPoint);
    when(leaf.getLiftedLeafPoint()).thenReturn(distantPoint);
    when(leaf.getSequenceIndexes()).thenReturn(indexCounts);

    int leafDepth = 12;
    visitor.acceptLeaf(leaf, leafDepth);

    Optional<Neighbor> outcome = visitor.getResult();
    assertFalse(outcome.isPresent());
}
Use of com.amazon.randomcutforest.returntypes.Neighbor in the project random-cut-forest-by-aws by AWS.
Class NearNeighborVisitor, method acceptLeaf.
/**
 * Records the visited leaf as the near neighbor if its Euclidean distance to the
 * query point is strictly below {@code distanceThreshold}.
 *
 * <p>The distance is measured using the leaf's lifted point
 * ({@code leafNode.getLiftedLeafPoint()}). On a match, the visitor builds a new
 * {@link Neighbor} from three parts: the lifted point converted to a
 * {@code double[]}, the computed distance, and a fresh list of the keys of the
 * leaf's sequence-index map. That Neighbor is what the visitor reports, wrapped in
 * a {@link java.util.Optional}.
 *
 * <p>If the leaf is not close enough, or the distance is NaN, the stored neighbor
 * is left unchanged. The result therefore stays empty unless a neighbor was
 * already recorded.
 *
 * @param leafNode    the leaf node being visited
 * @param depthOfNode the depth of the leaf node (not used by this visitor)
 */
@Override
public void acceptLeaf(INodeView leafNode, int depthOfNode) {
    float[] candidate = leafNode.getLiftedLeafPoint();
    double sumOfSquares = 0.0;
    int dim = 0;
    while (dim < candidate.length) {
        double delta = queryPoint[dim] - candidate[dim];
        sumOfSquares += delta * delta;
        dim++;
    }
    double distance = Math.sqrt(sumOfSquares);
    // Negated "<" so that a NaN distance is rejected, just like a too-large one.
    if (!(distance < distanceThreshold)) {
        return;
    }
    List<Long> indexes = new ArrayList<>(leafNode.getSequenceIndexes().keySet());
    neighbor = new Neighbor(toDoubleArray(candidate), distance, indexes);
}
Use of com.amazon.randomcutforest.returntypes.Neighbor in the project random-cut-forest-by-aws by AWS.
Class NearNeighborVisitorTest, method acceptLeafNear.
/**
 * A leaf close to the query point must produce a neighbor. The neighbor must carry
 * the leaf point (as doubles), the Euclidean distance, and the keys of the leaf's
 * sequence-index map.
 *
 * <p>NOTE(review): the expected distance sqrt(3 * 1.1^2) assumes the visitor's query
 * point (set up outside this method) differs from the leaf point by 1.1 in each
 * coordinate.
 */
@Test
public void acceptLeafNear() {
    float[] leafPoint = new float[] { 8.8f, 9.9f, -5.5f };
    INodeView leafNode = mock(NodeView.class);
    // Give the mock copies, so any mutation by the visitor cannot change the expected
    // values derived from leafPoint below.
    when(leafNode.getLeafPoint()).thenReturn(Arrays.copyOf(leafPoint, leafPoint.length));
    when(leafNode.getLiftedLeafPoint()).thenReturn(Arrays.copyOf(leafPoint, leafPoint.length));
    HashMap<Long, Integer> sequenceIndexes = new HashMap<>();
    sequenceIndexes.put(1234L, 1);
    sequenceIndexes.put(5678L, 1);
    when(leafNode.getSequenceIndexes()).thenReturn(sequenceIndexes);
    int depth = 12;
    visitor.acceptLeaf(leafNode, depth);
    Optional<Neighbor> optional = visitor.getResult();
    assertTrue(optional.isPresent());
    Neighbor neighbor = optional.get();
    assertNotSame(leafPoint, neighbor.point);
    assertArrayEquals(toDoubleArray(leafPoint), neighbor.point);
    assertEquals(Math.sqrt(3 * 1.1 * 1.1), neighbor.distance, EPSILON);
    assertNotSame(leafNode.getSequenceIndexes(), neighbor.sequenceIndexes);
    // The identity check above cannot fail (a map is never the same object as a list),
    // so check explicitly that every sequence-index key was carried into the neighbor.
    assertEquals(2, neighbor.sequenceIndexes.size());
    assertTrue(neighbor.sequenceIndexes.contains(1234L));
    assertTrue(neighbor.sequenceIndexes.contains(5678L));
}