Example usage of org.opensearch.ad.util.ClientUtil in the opensearch-project/anomaly-detection repository: class CheckpointDaoTests, method test_getModelCheckpoint_returnExpectedToListener.
/**
 * Verifies that a TRCF model checkpoint saved via putTRCFCheckpoint can be read
 * back through getTRCFModel, and that the deserialized forest matches the
 * original in dimensions, number of trees, and sample size.
 *
 * Fix: the save latch is now awaited before the read phase starts (the original
 * created and counted down the latch but never awaited it, so a failed or slow
 * save could not fail the test); the stale commented-out ArgumentCaptor line
 * has been removed.
 */
@SuppressWarnings("unchecked")
public void test_getModelCheckpoint_returnExpectedToListener() throws InterruptedException {
    UpdateResponse updateResponse = new UpdateResponse(
        new ReplicationResponse.ShardInfo(3, 2),
        new ShardId(CommonName.CHECKPOINT_INDEX_NAME, "uuid", 2),
        CommonName.CHECKPOINT_INDEX_NAME,
        "1",
        7,
        17,
        2,
        UPDATED
    );
    AtomicReference<GetRequest> getRequest = new AtomicReference<>();
    // Route both request types through the same stub: the UpdateRequest's doc
    // becomes the source returned by the later GetRequest, emulating a
    // write-then-read round trip against the checkpoint index.
    doAnswer(invocation -> {
        ActionRequest request = invocation.getArgument(0);
        if (request instanceof GetRequest) {
            getRequest.set((GetRequest) request);
            ActionListener<GetResponse> listener = invocation.getArgument(2);
            listener.onResponse(getResponse);
        } else {
            UpdateRequest updateRequest = (UpdateRequest) request;
            when(getResponse.getSource()).thenReturn(updateRequest.doc().sourceAsMap());
            ActionListener<UpdateResponse> listener = invocation.getArgument(2);
            listener.onResponse(updateResponse);
        }
        return null;
    }).when(clientUtil).asyncRequest(any(), any(BiConsumer.class), any(ActionListener.class));
    when(getResponse.isExists()).thenReturn(true);

    ThresholdedRandomCutForest trcf = createTRCF();
    final CountDownLatch inProgressLatch = new CountDownLatch(1);
    checkpointDao.putTRCFCheckpoint(modelId, trcf, ActionListener.wrap(response -> inProgressLatch.countDown(), exception -> {
        assertTrue("Should not reach here ", false);
        inProgressLatch.countDown();
    }));
    // Ensure the checkpoint write completed before reading it back.
    assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS));

    ActionListener<Optional<ThresholdedRandomCutForest>> listener = mock(ActionListener.class);
    checkpointDao.getTRCFModel(modelId, listener);

    GetRequest capturedGetRequest = getRequest.get();
    assertEquals(indexName, capturedGetRequest.index());
    assertEquals(modelId, capturedGetRequest.id());

    ArgumentCaptor<Optional<ThresholdedRandomCutForest>> responseCaptor = ArgumentCaptor.forClass(Optional.class);
    verify(listener).onResponse(responseCaptor.capture());
    Optional<ThresholdedRandomCutForest> result = responseCaptor.getValue();
    assertTrue(result.isPresent());
    RandomCutForest deserializedForest = result.get().getForest();
    RandomCutForest serializedForest = trcf.getForest();
    assertEquals(deserializedForest.getDimensions(), serializedForest.getDimensions());
    assertEquals(deserializedForest.getNumberOfTrees(), serializedForest.getNumberOfTrees());
    assertEquals(deserializedForest.getSampleSize(), serializedForest.getSampleSize());
}
Example usage of org.opensearch.ad.util.ClientUtil in the opensearch-project/anomaly-detection repository: class CheckpointDaoTests, method test_getModelCheckpoint_Bwc.
/**
 * Backward-compatibility variant of the checkpoint round-trip test: a TRCF
 * model saved via putTRCFCheckpoint must be readable through getTRCFModel,
 * with the deserialized forest matching the original in dimensions, number
 * of trees, and sample size.
 *
 * Fix: the save latch is now awaited before the read phase starts (the
 * original created and counted down the latch but never awaited it); the
 * stale commented-out ArgumentCaptor line has been removed.
 */
@SuppressWarnings("unchecked")
public void test_getModelCheckpoint_Bwc() throws InterruptedException {
    UpdateResponse updateResponse = new UpdateResponse(
        new ReplicationResponse.ShardInfo(3, 2),
        new ShardId(CommonName.CHECKPOINT_INDEX_NAME, "uuid", 2),
        CommonName.CHECKPOINT_INDEX_NAME,
        "1",
        7,
        17,
        2,
        UPDATED
    );
    AtomicReference<GetRequest> getRequest = new AtomicReference<>();
    // Route both request types through the same stub: the UpdateRequest's doc
    // becomes the source returned by the later GetRequest, emulating a
    // write-then-read round trip against the checkpoint index.
    doAnswer(invocation -> {
        ActionRequest request = invocation.getArgument(0);
        if (request instanceof GetRequest) {
            getRequest.set((GetRequest) request);
            ActionListener<GetResponse> listener = invocation.getArgument(2);
            listener.onResponse(getResponse);
        } else {
            UpdateRequest updateRequest = (UpdateRequest) request;
            when(getResponse.getSource()).thenReturn(updateRequest.doc().sourceAsMap());
            ActionListener<UpdateResponse> listener = invocation.getArgument(2);
            listener.onResponse(updateResponse);
        }
        return null;
    }).when(clientUtil).asyncRequest(any(), any(BiConsumer.class), any(ActionListener.class));
    when(getResponse.isExists()).thenReturn(true);

    ThresholdedRandomCutForest trcf = createTRCF();
    final CountDownLatch inProgressLatch = new CountDownLatch(1);
    checkpointDao.putTRCFCheckpoint(modelId, trcf, ActionListener.wrap(response -> inProgressLatch.countDown(), exception -> {
        assertTrue("Should not reach here ", false);
        inProgressLatch.countDown();
    }));
    // Ensure the checkpoint write completed before reading it back.
    assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS));

    ActionListener<Optional<ThresholdedRandomCutForest>> listener = mock(ActionListener.class);
    checkpointDao.getTRCFModel(modelId, listener);

    GetRequest capturedGetRequest = getRequest.get();
    assertEquals(indexName, capturedGetRequest.index());
    assertEquals(modelId, capturedGetRequest.id());

    ArgumentCaptor<Optional<ThresholdedRandomCutForest>> responseCaptor = ArgumentCaptor.forClass(Optional.class);
    verify(listener).onResponse(responseCaptor.capture());
    Optional<ThresholdedRandomCutForest> result = responseCaptor.getValue();
    assertTrue(result.isPresent());
    RandomCutForest deserializedForest = result.get().getForest();
    RandomCutForest serializedForest = trcf.getForest();
    assertEquals(deserializedForest.getDimensions(), serializedForest.getDimensions());
    assertEquals(deserializedForest.getNumberOfTrees(), serializedForest.getNumberOfTrees());
    assertEquals(deserializedForest.getSampleSize(), serializedForest.getSampleSize());
}
Example usage of org.opensearch.ad.util.ClientUtil in the opensearch-project/anomaly-detection repository: class CheckpointDaoTests, method test_batch_read.
/**
 * batchRead must deliver a MultiGetResponse to its listener even when the
 * only item is a failure (index missing), and must hit the client exactly once.
 */
@SuppressWarnings("unchecked")
public void test_batch_read() throws InterruptedException {
    // Stub the multi-get call to answer synchronously with a single failed item.
    doAnswer(invocation -> {
        MultiGetResponse.Failure failure = new MultiGetResponse.Failure(
            CommonName.CHECKPOINT_INDEX_NAME,
            "_doc",
            "modelId",
            new IndexNotFoundException(CommonName.CHECKPOINT_INDEX_NAME)
        );
        ActionListener<MultiGetResponse> listener = invocation.getArgument(2);
        listener.onResponse(new MultiGetResponse(new MultiGetItemResponse[] { new MultiGetItemResponse(null, failure) }));
        return null;
    }).when(clientUtil).execute(eq(MultiGetAction.INSTANCE), any(MultiGetRequest.class), any(ActionListener.class));

    final CountDownLatch doneLatch = new CountDownLatch(1);
    checkpointDao.batchRead(new MultiGetRequest(), ActionListener.wrap(response -> doneLatch.countDown(), e -> {
        assertTrue(false);
    }));

    // The stub responds synchronously, so the latch should already be at zero.
    assertTrue(doneLatch.await(100, TimeUnit.SECONDS));
    verify(clientUtil, times(1)).execute(any(), any(), any());
}
Example usage of org.opensearch.ad.util.ClientUtil in the opensearch-project/anomaly-detection repository: class CheckpointDao, method saveModelCheckpointAsync.
/**
 * Asynchronously persists a model checkpoint by updating only the fields in
 * {@code source}. Because untouched fields of any existing checkpoint doc are
 * left intact, nodes running old and new serialization logic can coexist in
 * one cluster — useful when introducing the compact RCF model format.
 *
 * @param source fields to update
 * @param modelId model Id, used as doc id in the checkpoint index
 * @param listener Listener to return response
 */
private void saveModelCheckpointAsync(Map<String, Object> source, String modelId, ActionListener<Void> listener) {
    // docAsUpsert: insert the map as a brand-new document when none exists;
    // otherwise update only the supplied fields.
    UpdateRequest request = new UpdateRequest(indexName, modelId).doc(source).docAsUpsert(true);
    clientUtil
        .<UpdateRequest, UpdateResponse>asyncRequest(
            request,
            client::update,
            ActionListener.wrap(response -> listener.onResponse(null), listener::onFailure)
        );
}
Example usage of org.opensearch.ad.util.ClientUtil in the opensearch-project/anomaly-detection repository: class SearchFeatureDao, method getLatestDataTime.
/**
 * Returns the epoch time (milliseconds) of the most recent document under the
 * detector's indices, found via a max aggregation on the detector's time field.
 *
 * @deprecated use getLatestDataTime with listener instead.
 *
 * @param detector info about the indices and documents
 * @return epoch time of the latest data in milliseconds
 */
@Deprecated
public Optional<Long> getLatestDataTime(AnomalyDetector detector) {
    // size(0): we only need the aggregation result, not the matching docs.
    SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
        .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField()))
        .size(0);
    SearchRequest request = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(sourceBuilder);
    Optional<SearchResponse> response = clientUtil.<SearchRequest, SearchResponse>timedRequest(request, logger, client::search);
    return response
        .map(SearchResponse::getAggregations)
        .map(aggregations -> aggregations.asMap())
        .map(aggMap -> (Max) aggMap.get(CommonName.AGG_NAME_MAX_TIME))
        .map(maxAgg -> (long) maxAgg.getValue());
}
Aggregations