Use of org.opensearch.cluster.service.ClusterService in project k-NN by opensearch-project.
Class VectorReaderTests, method testRead_valid_incompleteIndex.
public void testRead_valid_incompleteIndex() throws InterruptedException, ExecutionException, IOException {
    // Check that we get the right number of vectors when the index contains docs that are missing the field
    // Create an index with knn disabled
    String indexName = "test-index";
    String fieldName = "test-field";
    int dim = 16;
    int numVectors = 100;
    createIndex(indexName);

    // Add a field mapping to the index
    createKnnIndexMapping(indexName, fieldName, dim);

    // Create a list of random vectors and ingest them
    Random random = new Random();
    List<Float[]> vectors = new ArrayList<>();
    for (int i = 0; i < numVectors; i++) {
        Float[] vector = new Float[dim];
        for (int j = 0; j < dim; j++) {
            vector[j] = random.nextFloat();
        }
        vectors.add(vector);
        addKnnDoc(indexName, Integer.toString(i), fieldName, vector);
    }

    // Create documents that do not have fieldName for training
    int docsWithoutKNN = 100;
    String fieldNameWithoutKnn = "test-field-2";
    for (int i = 0; i < docsWithoutKNN; i++) {
        addDoc(indexName, Integer.toString(i + numVectors), fieldNameWithoutKnn, "dummyValue");
    }

    // Configure VectorReader
    ClusterService clusterService = node().injector().getInstance(ClusterService.class);
    VectorReader vectorReader = new VectorReader(client());

    // Read all vectors and confirm they match the ingested vectors
    TestVectorConsumer testVectorConsumer = new TestVectorConsumer();
    final CountDownLatch inProgressLatch1 = new CountDownLatch(1);
    vectorReader.read(
        clusterService,
        indexName,
        fieldName,
        10000,
        10,
        testVectorConsumer,
        ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString()))
    );
    assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS));

    List<Float[]> consumedVectors = testVectorConsumer.getVectorsConsumed();
    assertEquals(numVectors, consumedVectors.size());

    List<Float> flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList());
    List<Float> flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList());
    assertEquals(new HashSet<>(flatVectors), new HashSet<>(flatConsumedVectors));
}
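The TestVectorConsumer helper used throughout these tests is not included in the snippet. A minimal sketch of what it could look like, assuming VectorReader passes each page of results to a Consumer<List<Float[]>> and the test only needs to accumulate them:

private static class TestVectorConsumer implements Consumer<List<Float[]>> {

    // Accumulates every vector handed to the consumer so tests can compare against what was indexed
    private final List<Float[]> vectorsConsumed = new ArrayList<>();

    @Override
    public void accept(List<Float[]> vectors) {
        vectorsConsumed.addAll(vectors);
    }

    List<Float[]> getVectorsConsumed() {
        return vectorsConsumed;
    }
}

The consumer interface and method names here are assumptions inferred from the calls above (new TestVectorConsumer() and getVectorsConsumed()); the actual helper in VectorReaderTests may differ.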
Use of org.opensearch.cluster.service.ClusterService in project k-NN by opensearch-project.
Class VectorReaderTests, method testRead_invalid_searchSize.
public void testRead_invalid_searchSize() {
    // Create the index
    String indexName = "test-index";
    String fieldName = "test-field";
    int dim = 16;
    createIndex(indexName);

    // Add a field mapping to the index
    createKnnIndexMapping(indexName, fieldName, dim);

    // Configure VectorReader
    ClusterService clusterService = node().injector().getInstance(ClusterService.class);
    VectorReader vectorReader = new VectorReader(client());

    // Search size is negative
    expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 100, -10, null, null));

    // Search size is greater than 10000
    expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 100, 20000, null, null));
}
Use of org.opensearch.cluster.service.ClusterService in project k-NN by opensearch-project.
Class VectorReaderTests, method testRead_valid_OnlyGetMaxVectors.
public void testRead_valid_OnlyGetMaxVectors() throws InterruptedException, ExecutionException, IOException {
    // Check that we can limit the number of vectors read via the max vector count parameter
    // Create an index with knn disabled
    String indexName = "test-index";
    String fieldName = "test-field";
    int dim = 16;
    int numVectorsIndex = 100;
    int maxNumVectorsRead = 20;
    createIndex(indexName);

    // Add a field mapping to the index
    createKnnIndexMapping(indexName, fieldName, dim);

    // Create a list of random vectors and ingest them
    Random random = new Random();
    for (int i = 0; i < numVectorsIndex; i++) {
        Float[] vector = new Float[dim];
        for (int j = 0; j < dim; j++) {
            vector[j] = random.nextFloat();
        }
        addKnnDoc(indexName, Integer.toString(i), fieldName, vector);
    }

    // Configure VectorReader
    ClusterService clusterService = node().injector().getInstance(ClusterService.class);
    VectorReader vectorReader = new VectorReader(client());

    // Read maxNumVectorsRead vectors
    TestVectorConsumer testVectorConsumer = new TestVectorConsumer();
    final CountDownLatch inProgressLatch1 = new CountDownLatch(1);
    vectorReader.read(
        clusterService,
        indexName,
        fieldName,
        maxNumVectorsRead,
        10,
        testVectorConsumer,
        ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString()))
    );
    assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS));

    List<Float[]> consumedVectors = testVectorConsumer.getVectorsConsumed();
    assertEquals(maxNumVectorsRead, consumedVectors.size());
}
Use of org.opensearch.cluster.service.ClusterService in project k-NN by opensearch-project.
Class VectorReaderTests, method testRead_invalid_indexDoesNotExist.
public void testRead_invalid_indexDoesNotExist() {
    // Check that read throws a validation exception when the index does not exist
    String indexName = "test-index";
    String fieldName = "test-field";

    // Configure VectorReader
    ClusterService clusterService = node().injector().getInstance(ClusterService.class);
    VectorReader vectorReader = new VectorReader(client());

    // Should throw a validation exception because the index does not exist
    expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null));
}
Use of org.opensearch.cluster.service.ClusterService in project anomaly-detection by opensearch-project.
Class NoPowermockSearchFeatureDaoTests, method testGetHighestCountEntitiesExhaustedPages.
@SuppressWarnings("unchecked")
public void testGetHighestCountEntitiesExhaustedPages() throws InterruptedException {
    SearchResponse response1 = createPageResponse(attrs1);

    CompositeAggregation emptyComposite = mock(CompositeAggregation.class);
    when(emptyComposite.getName()).thenReturn(SearchFeatureDao.AGG_NAME_TOP);
    when(emptyComposite.afterKey()).thenReturn(null);
    // empty bucket
    when(emptyComposite.getBuckets())
        .thenAnswer((Answer<List<CompositeAggregation.Bucket>>) invocation -> new ArrayList<CompositeAggregation.Bucket>());
    Aggregations emptyAggs = new Aggregations(Collections.singletonList(emptyComposite));
    SearchResponseSections emptySections = new SearchResponseSections(SearchHits.empty(), emptyAggs, null, false, null, null, 1);
    SearchResponse emptyResponse = new SearchResponse(emptySections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, Clusters.EMPTY);

    CountDownLatch inProgress = new CountDownLatch(2);
    doAnswer(invocation -> {
        ActionListener<SearchResponse> listener = invocation.getArgument(1);
        inProgress.countDown();
        if (inProgress.getCount() == 1) {
            listener.onResponse(response1);
        } else {
            listener.onResponse(emptyResponse);
        }
        return null;
    }).when(client).search(any(), any());

    ActionListener<List<Entity>> listener = mock(ActionListener.class);
    searchFeatureDao = new SearchFeatureDao(
        client,
        xContentRegistry(),
        interpolator,
        clientUtil,
        settings,
        clusterService,
        AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE,
        clock,
        2,
        1,
        60_000L
    );

    searchFeatureDao.getHighestCountEntities(detector, 10L, 20L, listener);

    ArgumentCaptor<List<Entity>> captor = ArgumentCaptor.forClass(List.class);
    verify(listener).onResponse(captor.capture());
    List<Entity> result = captor.getValue();
    assertEquals(1, result.size());
    assertEquals(Entity.createEntityByReordering(attrs1), result.get(0));

    // both counts are used in client.search
    assertTrue(inProgress.await(10000L, TimeUnit.MILLISECONDS));
}
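The createPageResponse helper and the attrs1 fixture are not shown in this snippet. A minimal sketch of how a one-bucket page response could be built, mirroring the emptyResponse construction above; the Map<String, Object> type of attrs, the bucket key, and the doc count value are assumptions rather than the project's actual helper:

private SearchResponse createPageResponse(Map<String, Object> attrs) {
    CompositeAggregation pageOneComposite = mock(CompositeAggregation.class);
    when(pageOneComposite.getName()).thenReturn(SearchFeatureDao.AGG_NAME_TOP);
    // A non-null afterKey signals that another page should be requested
    when(pageOneComposite.afterKey()).thenReturn(attrs);

    CompositeAggregation.Bucket bucket = mock(CompositeAggregation.Bucket.class);
    when(bucket.getKey()).thenReturn(attrs);
    when(bucket.getDocCount()).thenReturn(3L);
    when(pageOneComposite.getBuckets())
        .thenAnswer((Answer<List<CompositeAggregation.Bucket>>) invocation -> Collections.singletonList(bucket));

    Aggregations aggs = new Aggregations(Collections.singletonList(pageOneComposite));
    SearchResponseSections sections = new SearchResponseSections(SearchHits.empty(), aggs, null, false, null, null, 1);
    return new SearchResponse(sections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, Clusters.EMPTY);
}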