Usage of `org.opensearch.ad.ml.ThresholdingResult` in the opensearch-project anomaly-detection project.
Class `TestHelpers`, method `randomThresholdingResults`:
/**
 * Builds a new, mutable, single-element list holding one {@link ThresholdingResult}.
 *
 * <p>Despite the method name, the contents are fixed rather than random: the single result is
 * created with grade {@code 1.0}, confidence {@code 0.5} and score {@code 1.0}.
 *
 * @return a fresh {@link ArrayList} containing exactly one result
 */
public static List<ThresholdingResult> randomThresholdingResults() {
    final double fixedGrade = 1.;
    final double fixedConfidence = 0.5;
    final double fixedScore = 1.;

    List<ThresholdingResult> singleResult = new ArrayList<>();
    singleResult.add(new ThresholdingResult(fixedGrade, fixedConfidence, fixedScore));
    return singleResult;
}
Usage of `org.opensearch.ad.ml.ThresholdingResult` in the opensearch-project anomaly-detection project.
Class `ThresholdResultTests`, method `testNormal`:
/**
 * Checks that {@code ThresholdResultTransportAction} surfaces the anomaly grade and confidence
 * that the (mocked) model manager hands back from {@code getThresholdingResult}.
 */
@SuppressWarnings("unchecked")
public void testNormal() {
    TransportService transport = new TransportService(
        Settings.EMPTY,
        mock(Transport.class),
        null,
        TransportService.NOOP_TRANSPORT_INTERCEPTOR,
        unused -> null,
        null,
        Collections.emptySet()
    );
    ModelManager modelManager = mock(ModelManager.class);
    ThresholdResultTransportAction thresholdAction = new ThresholdResultTransportAction(
        mock(ActionFilters.class),
        transport,
        modelManager
    );

    // Any thresholding lookup answers its listener (argument index 3) with grade 0, confidence 1.0, score 0.2.
    doAnswer(call -> {
        ActionListener<ThresholdingResult> resultListener = call.getArgument(3);
        resultListener.onResponse(new ThresholdingResult(0, 1.0d, 0.2));
        return null;
    }).when(modelManager).getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class));

    final PlainActionFuture<ThresholdResultResponse> responseFuture = new PlainActionFuture<>();
    ThresholdResultRequest thresholdRequest = new ThresholdResultRequest("123", "123-threshold", 2);
    thresholdAction.doExecute(mock(Task.class), thresholdRequest, responseFuture);

    ThresholdResultResponse thresholdResponse = responseFuture.actionGet();
    assertEquals(0, thresholdResponse.getAnomalyGrade(), 0.001);
    assertEquals(1, thresholdResponse.getConfidence(), 0.001);
}
Usage of `org.opensearch.ad.ml.ThresholdingResult` in the opensearch-project anomaly-detection project.
Class `AnomalyResultTests`, method `setUp`:
/**
 * Shared fixture for the {@code AnomalyResultTests} cases.
 *
 * <p>Starts the test nodes, then wires mocks for: the node state manager, a single-feature detector
 * with id {@code "123"}, a hash ring that always resolves to the local node, the feature manager,
 * the model manager, the circuit breaker (closed), the client's index/get calls, AD stats counters,
 * and the task manager. Individual tests re-stub parts of this wiring as needed.
 */
@SuppressWarnings("unchecked")
@Override
@Before
public void setUp() throws Exception {
    super.setUp();
    super.setUpLog4jForJUnit(AnomalyResultTransportAction.class);
    setupTestNodes(AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, AnomalyDetectorSettings.PAGE_SIZE);
    transportService = testNodes[0].transportService;
    clusterService = testNodes[0].clusterService;
    settings = clusterService.getSettings();

    // Node state: the detector is never muted and marking a cold start returns a no-op lambda.
    stateManager = mock(NodeStateManager.class);
    when(stateManager.isMuted(any(String.class), any(String.class))).thenReturn(false);
    when(stateManager.markColdStartRunning(anyString())).thenReturn(() -> {
    });

    // Detector "123": exactly one enabled feature, source indices "test*", no category field.
    detector = mock(AnomalyDetector.class);
    featureId = "xyz";
    when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList(featureId));
    featureName = "abc";
    when(detector.getEnabledFeatureNames()).thenReturn(Collections.singletonList(featureName));
    List<String> userIndex = new ArrayList<>();
    userIndex.add("test*");
    when(detector.getIndices()).thenReturn(userIndex);
    adID = "123";
    when(detector.getDetectorId()).thenReturn(adID);
    when(detector.getCategoryField()).thenReturn(null);
    doAnswer(invocation -> {
        ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
        listener.onResponse(Optional.of(detector));
        return null;
    }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class));
    when(detector.getDetectorIntervalInMinutes()).thenReturn(1L);

    // Hash ring: every lookup resolves to the local node.
    hashRing = mock(HashRing.class);
    Optional<DiscoveryNode> localNode = Optional.of(clusterService.state().nodes().getLocalNode());
    when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(localNode);
    doReturn(localNode).when(hashRing).getNodeByAddress(any());

    // Feature manager: current features are a single value 0.0.
    featureQuery = mock(FeatureManager.class);
    doAnswer(invocation -> {
        ActionListener<SinglePointFeatures> listener = invocation.getArgument(3);
        listener.onResponse(new SinglePointFeatures(Optional.of(new double[] { 0.0d }), Optional.of(new double[] { 0 })));
        return null;
    }).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class));

    // Model manager: canned TRCF and thresholding results.
    double rcfScore = 0.2;
    confidence = 0.91;
    anomalyGrade = 0.5;
    normalModelManager = mock(ModelManager.class);
    long totalUpdates = 1440;
    int relativeIndex = 0;
    double[] currentTimeAttribution = new double[] { 0.5, 0.5 };
    double[] pastValues = new double[] { 123, 456 };
    double[][] expectedValuesList = new double[][] { new double[] { 789, 12 } };
    double[] likelihood = new double[] { 1 };
    double threshold = 1.1d;
    doAnswer(invocation -> {
        ActionListener<ThresholdingResult> listener = invocation.getArgument(3);
        listener
            .onResponse(
                new ThresholdingResult(
                    anomalyGrade,
                    confidence,
                    rcfScore,
                    totalUpdates,
                    relativeIndex,
                    currentTimeAttribution,
                    pastValues,
                    expectedValuesList,
                    likelihood,
                    threshold,
                    30
                )
            );
        return null;
    }).when(normalModelManager).getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class));
    doAnswer(invocation -> {
        ActionListener<ThresholdingResult> listener = invocation.getArgument(3);
        listener.onResponse(new ThresholdingResult(0, 1.0d, rcfScore));
        return null;
    }).when(normalModelManager).getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class));

    // For adID "123" the original author noted this as "123-threshold".
    thresholdModelID = SingleStreamModelIdMapper.getThresholdModelId(adID);

    adCircuitBreakerService = mock(ADCircuitBreakerService.class);
    when(adCircuitBreakerService.isOpen()).thenReturn(false);

    // Client: mocked thread pool with a real, empty ThreadContext.
    ThreadPool threadPool = mock(ThreadPool.class);
    client = mock(Client.class);
    ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
    when(client.threadPool()).thenReturn(threadPool);
    when(client.threadPool().getThreadContext()).thenReturn(threadContext);

    // Client index(): requires (IndexRequest, ActionListener) and acknowledges with a fresh IndexResponse.
    doAnswer(invocation -> {
        Object[] args = invocation.getArguments();
        assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length >= 2);
        IndexRequest request = null;
        ActionListener<IndexResponse> listener = null;
        if (args[0] instanceof IndexRequest) {
            request = (IndexRequest) args[0];
        }
        if (args[1] instanceof ActionListener) {
            listener = (ActionListener<IndexResponse>) args[1];
        }
        assertTrue(request != null && listener != null);
        ShardId shardId = new ShardId(new Index(CommonName.ANOMALY_RESULT_INDEX_ALIAS, randomAlphaOfLength(10)), 0);
        listener.onResponse(new IndexResponse(shardId, randomAlphaOfLength(10), request.id(), 1, 1, 1, true));
        return null;
    }).when(client).index(any(), any());

    indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY));

    // AD stats counters. Built as a plain HashMap rather than via double-brace initialization,
    // which would create an anonymous HashMap subclass holding a reference to this test instance.
    Map<String, ADStat<?>> statsMap = new HashMap<>();
    statsMap.put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
    statsMap.put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
    statsMap.put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
    statsMap.put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
    adStats = new ADStats(statsMap);

    // Client get(): answers only requests for the detection state index, returning a state with
    // lastUpdateTime = now. Requests for any other index are left unanswered (listener never invoked).
    doAnswer(invocation -> {
        Object[] args = invocation.getArguments();
        GetRequest request = (GetRequest) args[0];
        ActionListener<GetResponse> listener = (ActionListener<GetResponse>) args[1];
        if (request.index().equals(CommonName.DETECTION_STATE_INDEX)) {
            DetectorInternalState.Builder result = new DetectorInternalState.Builder().lastUpdateTime(Instant.now());
            listener.onResponse(TestHelpers.createGetResponse(result.build(), detector.getDetectorId(), CommonName.DETECTION_STATE_INDEX));
        }
        return null;
    }).when(client).get(any(), any());

    // Task manager: realtime task cache initialization always reports success (true).
    adTaskManager = mock(ADTaskManager.class);
    doAnswer(invocation -> {
        ActionListener<Boolean> listener = invocation.getArgument(3);
        listener.onResponse(true);
        return null;
    })
        .when(adTaskManager)
        .initRealtimeTaskCacheAndCleanupStaleCache(anyString(), any(AnomalyDetector.class), any(TransportService.class), any(ActionListener.class));
}
Usage of `org.opensearch.ad.ml.ThresholdingResult` in the opensearch-project anomaly-detection project.
Class `AnomalyResultTests`, method `testEndRunDueToNoTrainingData`:
/**
 * Checks the end-run path: the RCF lookup fails with a missing checkpoint index, and the state
 * manager hands back a stored {@code EndRunException("Cannot get training data")}. The anomaly
 * result request is expected to fail with {@code EndRunException}, and
 * {@code markColdStartRunning(adID)} is expected to be called exactly once.
 */
@SuppressWarnings("unchecked")
public void testEndRunDueToNoTrainingData() {
    ThreadPool coldStartThreadPool = mock(ThreadPool.class);
    setUpColdStart(coldStartThreadPool, new ColdStartConfig.Builder().coldStartRunning(false).build());

    // RCF lookups fail as if the checkpoint index did not exist.
    ModelManager failingRcfManager = mock(ModelManager.class);
    doAnswer(call -> {
        ActionListener<ThresholdingResult> rcfListener = call.getArgument(3);
        rcfListener.onFailure(new IndexNotFoundException(CommonName.CHECKPOINT_INDEX_NAME));
        return null;
    }).when(failingRcfManager).getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class));

    // The state manager reports a previously recorded end-run exception.
    EndRunException noTrainingData = new EndRunException(adID, "Cannot get training data", false);
    when(stateManager.fetchExceptionAndClear(any(String.class))).thenReturn(Optional.of(noTrainingData));

    // Cold-start data lookup answers with a single sample { { 1.0 } }.
    doAnswer(call -> {
        ActionListener<Optional<double[][]>> coldStartListener = call.getArgument(1);
        coldStartListener.onResponse(Optional.of(new double[][] { { 1.0 } }));
        return null;
    }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class));

    // Training answers its listener synchronously with a null response.
    doAnswer(call -> {
        ActionListener<Optional<Void>> trainListener = call.getArgument(2);
        trainListener.onResponse(null);
        return null;
    }).when(normalModelManager).trainModel(any(AnomalyDetector.class), any(double[][].class), any(ActionListener.class));

    // Constructing these actions registers their request handlers on transportService.
    new RCFResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, failingRcfManager, adCircuitBreakerService, hashRing);
    new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager);
    AnomalyResultTransportAction resultAction = new AnomalyResultTransportAction(
        new ActionFilters(Collections.emptySet()),
        transportService,
        settings,
        client,
        stateManager,
        featureQuery,
        normalModelManager,
        hashRing,
        clusterService,
        indexNameResolver,
        adCircuitBreakerService,
        adStats,
        coldStartThreadPool,
        NamedXContentRegistry.EMPTY,
        adTaskManager
    );

    AnomalyResultRequest resultRequest = new AnomalyResultRequest(adID, 100, 200);
    PlainActionFuture<AnomalyResultResponse> resultFuture = new PlainActionFuture<>();
    resultAction.doExecute(null, resultRequest, resultFuture);

    assertException(resultFuture, EndRunException.class);
    verify(stateManager, times(1)).markColdStartRunning(eq(adID));
}
Usage of `org.opensearch.ad.ml.ThresholdingResult` in the opensearch-project anomaly-detection project.
Class `RCFResultTests`, method `testNormal`:
/**
 * Checks that {@code RCFResultTransportAction} reports the RCF score, forest size and attribution
 * carried by the (mocked) {@code getTRcfResult} answer, with the circuit breaker stubbed as not open.
 * Relies on fixture fields ({@code grade}, {@code totalUpdates}, {@code attribution}, {@code pastValues},
 * {@code expectedValuesList}, {@code likelihood}, {@code threshold}, {@code hashRing}) defined elsewhere
 * in the test class.
 */
@SuppressWarnings("unchecked")
public void testNormal() {
    TransportService transport = new TransportService(
        Settings.EMPTY,
        mock(Transport.class),
        null,
        TransportService.NOOP_TRANSPORT_INTERCEPTOR,
        unused -> null,
        null,
        Collections.emptySet()
    );
    ModelManager modelManager = mock(ModelManager.class);
    ADCircuitBreakerService breaker = mock(ADCircuitBreakerService.class);
    RCFResultTransportAction rcfAction = new RCFResultTransportAction(mock(ActionFilters.class), transport, modelManager, breaker, hashRing);

    final double expectedScore = 0.5;
    final int expectedForestSize = 25;
    // Any TRCF lookup answers its listener (argument index 3) with the fixture values above.
    doAnswer(call -> {
        ActionListener<ThresholdingResult> rcfListener = call.getArgument(3);
        rcfListener
            .onResponse(
                new ThresholdingResult(
                    grade,
                    0d,
                    expectedScore,
                    totalUpdates,
                    0,
                    attribution,
                    pastValues,
                    expectedValuesList,
                    likelihood,
                    threshold,
                    expectedForestSize
                )
            );
        return null;
    }).when(modelManager).getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class));
    when(breaker.isOpen()).thenReturn(false);

    final PlainActionFuture<RCFResultResponse> responseFuture = new PlainActionFuture<>();
    RCFResultRequest rcfRequest = new RCFResultRequest("123", "123-rcf-1", new double[] { 0 });
    rcfAction.doExecute(mock(Task.class), rcfRequest, responseFuture);

    RCFResultResponse rcfResponse = responseFuture.actionGet();
    assertEquals(expectedScore, rcfResponse.getRCFScore(), 0.001);
    assertEquals(expectedForestSize, rcfResponse.getForestSize(), 0.001);
    assertTrue(Arrays.equals(attribution, rcfResponse.getAttribution()));
}
Aggregations