Use of org.opensearch.transport.TransportService in project anomaly-detection by opensearch-project.
The class EntityResultTransportActionTests, method setUp.
/**
 * Builds the fixtures shared by the entity-result tests.
 *
 * <p>Wires a {@link ModelManager} (reading time from a mocked {@link Clock} frozen at {@code now})
 * and mocked collaborators into {@link EntityResultTransportAction}, and prepares an
 * {@link EntityResultRequest} for detector {@code "123"} covering {@code [10, 20]} that carries
 * three entities keyed on the detector's first category field:
 * <ul>
 *   <li>{@code 0.0.0.1}: the entity cache returns {@code null} for its model id,</li>
 *   <li>{@code 0.0.0.2}: the entity cache returns a random, fully built model state,</li>
 *   <li>a random value one character longer than {@code MAX_ENTITY_LENGTH}.</li>
 * </ul>
 */
@SuppressWarnings("unchecked")
@Override
@Before
public void setUp() throws Exception {
    super.setUp();

    actionFilters = mock(ActionFilters.class);
    transportService = mock(TransportService.class);

    adCircuitBreakerService = mock(ADCircuitBreakerService.class);
    // The breaker never reports itself as open in these fixtures.
    when(adCircuitBreakerService.isOpen()).thenReturn(false);

    checkpointDao = mock(CheckpointDao.class);

    detectorId = "123";
    entities = new HashMap<>();
    start = 10L;
    end = 20L;

    // Freeze the clock so every clock.instant() call made through ModelManager sees the same instant.
    clock = mock(Clock.class);
    now = Instant.now();
    when(clock.instant()).thenReturn(now);
    manager = new ModelManager(null, clock, 0, 0, 0, 0, 0, 0, null, null, mock(EntityColdStarter.class), null, null);

    provider = mock(CacheProvider.class);
    entityCache = mock(EntityCache.class);
    when(provider.get()).thenReturn(entityCache);

    String field = "a";
    detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field));

    stateManager = mock(NodeStateManager.class);
    // Any detector lookup immediately completes its listener (argument 1) with the detector above.
    doAnswer(invocation -> {
        ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
        listener.onResponse(Optional.of(detector));
        return null;
    }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class));

    cacheMissEntity = "0.0.0.1";
    cacheMissData = new double[] { 0.1 };
    cacheHitEntity = "0.0.0.2";
    cacheHitData = new double[] { 0.2 };

    cacheMissEntityObj = Entity.createSingleAttributeEntity(detector.getCategoryField().get(0), cacheMissEntity);
    entities.put(cacheMissEntityObj, cacheMissData);
    cacheHitEntityObj = Entity.createSingleAttributeEntity(detector.getCategoryField().get(0), cacheHitEntity);
    entities.put(cacheHitEntityObj, cacheHitData);
    // One character over the limit; presumably meant to exercise over-length entity handling.
    tooLongEntity = randomAlphaOfLength(AnomalyDetectorSettings.MAX_ENTITY_LENGTH + 1);
    tooLongData = new double[] { 0.3 };
    entities.put(Entity.createSingleAttributeEntity(detector.getCategoryField().get(0), tooLongEntity), tooLongData);

    // Build the request only after the map is fully populated, so the fixture no longer depends
    // on EntityResultRequest holding a live reference to (rather than a copy of) the map.
    request = new EntityResultRequest(detectorId, entities, start, end);

    // Cache lookups: the cache-miss model id yields null, the cache-hit model id yields a full model state.
    ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build());
    when(entityCache.get(eq(cacheMissEntityObj.getModelId(detectorId).get()), any())).thenReturn(null);
    when(entityCache.get(eq(cacheHitEntityObj.getModelId(detectorId).get()), any())).thenReturn(state);

    // selectUpdateCandidate returns an empty first list and a second list holding only the
    // cache-miss entity. NOTE(review): presumably the second list names entities with no hosted
    // model (e.g. needing a checkpoint read) - confirm against EntityCache.
    List<Entity> coldEntities = new ArrayList<>();
    coldEntities.add(cacheMissEntityObj);
    when(entityCache.selectUpdateCandidate(any(), anyString(), any())).thenReturn(Pair.of(new ArrayList<>(), coldEntities));

    settings = Settings.builder().put(AnomalyDetectorSettings.COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)).build();

    AnomalyDetectionIndices indexUtil = mock(AnomalyDetectionIndices.class);
    when(indexUtil.getSchemaVersion(any())).thenReturn(CommonValue.NO_SCHEMA_VERSION);

    resultWriteQueue = mock(ResultWriteWorker.class);
    checkpointReadQueue = mock(CheckpointReadWorker.class);

    minSamples = 1;

    // NOTE(review): this stubbed cold starter is NOT the instance handed to `manager` above (that
    // one is a separate, unstubbed mock). Presumably individual tests wire coldStarter into their
    // own ModelManager - confirm before relying on this stub.
    coldStarter = mock(EntityColdStarter.class);
    // The stub calls clear() on the model of the ModelState passed as argument 0.
    doAnswer(invocation -> {
        ModelState<EntityModel> modelState = invocation.getArgument(0);
        modelState.getModel().clear();
        return null;
    }).when(coldStarter).trainModelFromExistingSamples(any(), anyInt());

    coldEntityQueue = mock(ColdEntityWorker.class);

    entityResult = new EntityResultTransportAction(
        actionFilters,
        transportService,
        manager,
        adCircuitBreakerService,
        provider,
        stateManager,
        indexUtil,
        resultWriteQueue,
        checkpointReadQueue,
        coldEntityQueue,
        threadPool
    );

    // timeout in 60 seconds
    timeoutMs = 60000L;
}
Use of org.opensearch.transport.TransportService in project anomaly-detection by opensearch-project.
The class RCFPollingTests, method setUp.
/**
 * Creates the shared fixtures for the RCF polling tests.
 *
 * <p>Sets up mocks for the cluster service, hash ring and model manager, and a standalone
 * {@link TransportService} on a mocked {@link Transport}. Also builds a polling request for
 * {@code detectorId} and two transport interceptors. The interceptors wrap the response handler
 * of {@link RCFPollingAction#NAME} requests and differ only in which wrapper they apply.
 */
@Override
@Before
public void setUp() throws Exception {
    super.setUp();
    clusterService = mock(ClusterService.class);
    hashRing = mock(HashRing.class);
    transportAddress1 = new TransportAddress(new InetSocketAddress(InetAddress.getByName("1.2.3.4"), 9300));
    manager = mock(ModelManager.class);
    transportService = new TransportService(
        Settings.EMPTY,
        mock(Transport.class),
        null,
        TransportService.NOOP_TRANSPORT_INTERCEPTOR,
        x -> null,
        null,
        Collections.emptySet()
    );
    future = new PlainActionFuture<>();

    request = new RCFPollingRequest(detectorId);
    model0Id = SingleStreamModelIdMapper.getRcfModelId(detectorId, 0);

    // getTotalUpdates(...) immediately completes its listener (argument 2) with totalUpdates.
    doAnswer(invocation -> {
        Object[] args = invocation.getArguments();
        @SuppressWarnings("unchecked")
        ActionListener<Long> listener = (ActionListener<Long>) args[2];
        listener.onResponse(totalUpdates);
        return null;
    }).when(manager).getTotalUpdates(any(String.class), any(String.class), any());

    normalTransportInterceptor = rcfPollingInterceptor(false);
    failureTransportInterceptor = rcfPollingInterceptor(true);
}

/**
 * Returns an interceptor that forwards every request unchanged. The one exception is that the
 * response handler of {@link RCFPollingAction#NAME} requests is wrapped: with
 * {@code rcfFailureRollingHandler} when {@code simulateFailure} is true, otherwise with
 * {@code rcfRollingHandler}.
 *
 * @param simulateFailure selects the failure wrapper instead of the normal one
 * @return the interceptor to install on the test transport
 */
private TransportInterceptor rcfPollingInterceptor(boolean simulateFailure) {
    return new TransportInterceptor() {
        @Override
        public AsyncSender interceptSender(AsyncSender sender) {
            return new AsyncSender() {
                @Override
                public <T extends TransportResponse> void sendRequest(
                    Transport.Connection connection,
                    String action,
                    TransportRequest request,
                    TransportRequestOptions options,
                    TransportResponseHandler<T> handler
                ) {
                    if (!RCFPollingAction.NAME.equals(action)) {
                        // Non-polling traffic passes through untouched.
                        sender.sendRequest(connection, action, request, options, handler);
                    } else if (simulateFailure) {
                        sender.sendRequest(connection, action, request, options, rcfFailureRollingHandler(handler));
                    } else {
                        sender.sendRequest(connection, action, request, options, rcfRollingHandler(handler));
                    }
                }
            };
        }
    };
}
Use of org.opensearch.transport.TransportService in project anomaly-detection by opensearch-project.
The class RCFPollingTests, method testGetRemoteNormalResponse.
/**
 * Polls through the normal interceptor while the hash ring points the model at testNodes[1],
 * and expects the completed response to report {@code totalUpdates}.
 */
public void testGetRemoteNormalResponse() {
    setupTestNodes(normalTransportInterceptor, Settings.EMPTY);
    try {
        // testNodes[0] runs the action; its cluster service replaces the mocked one.
        clusterService = testNodes[0].clusterService;
        TransportService coordinatorTransport = testNodes[0].transportService;
        ActionFilters noFilters = new ActionFilters(Collections.emptySet());
        action = new RCFPollingTransportAction(noFilters, coordinatorTransport, Settings.EMPTY, manager, hashRing, clusterService);

        // Route the model to testNodes[1] and register the polling handler on it.
        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
            .thenReturn(Optional.of(testNodes[1].discoveryNode()));
        registerHandler(testNodes[1]);

        action.doExecute(null, request, future);
        assertEquals(totalUpdates, future.actionGet().getTotalUpdates());
    } finally {
        // Release the test nodes even when the assertion fails.
        tearDownTestNodes();
    }
}
Use of org.opensearch.transport.TransportService in project anomaly-detection by opensearch-project.
The class RCFResultTests, method testExecutionException.
/**
 * Verifies that a {@link NullPointerException} thrown by {@code ModelManager#getTRcfResult}
 * is rethrown when the caller waits on the action's future.
 */
@SuppressWarnings("unchecked")
public void testExecutionException() {
    TransportService standaloneTransport = new TransportService(
        Settings.EMPTY,
        mock(Transport.class),
        null,
        TransportService.NOOP_TRANSPORT_INTERCEPTOR,
        ignored -> null,
        null,
        Collections.emptySet()
    );
    ModelManager failingManager = mock(ModelManager.class);
    ADCircuitBreakerService breaker = mock(ADCircuitBreakerService.class);
    RCFResultTransportAction rcfAction = new RCFResultTransportAction(
        mock(ActionFilters.class),
        standaloneTransport,
        failingManager,
        breaker,
        hashRing
    );

    // Scoring throws; the breaker reports that it is not open.
    doThrow(NullPointerException.class)
        .when(failingManager)
        .getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class));
    when(breaker.isOpen()).thenReturn(false);

    PlainActionFuture<RCFResultResponse> resultFuture = new PlainActionFuture<>();
    RCFResultRequest rcfRequest = new RCFResultRequest("123", "123-rcf-1", new double[] { 0 });
    rcfAction.doExecute(mock(Task.class), rcfRequest, resultFuture);

    expectThrows(NullPointerException.class, resultFuture::actionGet);
}
Use of org.opensearch.transport.TransportService in project anomaly-detection by opensearch-project.
The class RCFResultTests, method testCircuitBreaker.
/**
 * Verifies that while the circuit breaker reports open, the action's future fails with
 * {@link LimitExceededException}. The model manager is stubbed to answer successfully, so the
 * expected failure does not come from the scoring stub.
 */
@SuppressWarnings("unchecked")
public void testCircuitBreaker() {
    TransportService standaloneTransport = new TransportService(
        Settings.EMPTY,
        mock(Transport.class),
        null,
        TransportService.NOOP_TRANSPORT_INTERCEPTOR,
        ignored -> null,
        null,
        Collections.emptySet()
    );
    ModelManager stubbedManager = mock(ModelManager.class);
    ADCircuitBreakerService breaker = mock(ADCircuitBreakerService.class);
    RCFResultTransportAction rcfAction = new RCFResultTransportAction(
        mock(ActionFilters.class),
        standaloneTransport,
        stubbedManager,
        breaker,
        hashRing
    );

    // If consulted, the manager completes its listener (argument 3) with a fixed result.
    doAnswer(invocation -> {
        ActionListener<ThresholdingResult> resultListener = invocation.getArgument(3);
        resultListener.onResponse(
            new ThresholdingResult(grade, 0d, 0.5, totalUpdates, 0, attribution, pastValues, expectedValuesList, likelihood, threshold, 30)
        );
        return null;
    }).when(stubbedManager).getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class));
    when(breaker.isOpen()).thenReturn(true);

    PlainActionFuture<RCFResultResponse> resultFuture = new PlainActionFuture<>();
    RCFResultRequest rcfRequest = new RCFResultRequest("123", "123-rcf-1", new double[] { 0 });
    rcfAction.doExecute(mock(Task.class), rcfRequest, resultFuture);

    expectThrows(LimitExceededException.class, resultFuture::actionGet);
}
Aggregations