Search in sources:

Example 6 with ThresholdingResult

Use of org.opensearch.ad.ml.ThresholdingResult in project anomaly-detection by opensearch-project.

The class TestHelpers, method randomThresholdingResults.

/**
 * Creates a mutable, single-element list holding one {@link ThresholdingResult}
 * built from grade 1.0, confidence 0.5 and score 1.0.
 *
 * @return a new {@link ArrayList} containing exactly one thresholding result
 */
public static List<ThresholdingResult> randomThresholdingResults() {
    // Despite the method name, these values are fixed rather than randomly generated.
    final double fixedGrade = 1.;
    final double fixedConfidence = 0.5;
    final double fixedScore = 1.;
    List<ThresholdingResult> singleResult = new ArrayList<>();
    singleResult.add(new ThresholdingResult(fixedGrade, fixedConfidence, fixedScore));
    return singleResult;
}
Also used : ArrayList(java.util.ArrayList) ThresholdingResult(org.opensearch.ad.ml.ThresholdingResult)

Example 7 with ThresholdingResult

Use of org.opensearch.ad.ml.ThresholdingResult in project anomaly-detection by opensearch-project.

The class ThresholdResultTests, method testNormal.

/**
 * Checks that {@link ThresholdResultTransportAction} passes the grade and confidence
 * delivered by {@link ModelManager#getThresholdingResult} through to its
 * {@link ThresholdResultResponse}.
 */
@SuppressWarnings("unchecked")
public void testNormal() {
    TransportService noopTransport = new TransportService(
        Settings.EMPTY,
        mock(Transport.class),
        null,
        TransportService.NOOP_TRANSPORT_INTERCEPTOR,
        boundAddress -> null,
        null,
        Collections.emptySet()
    );
    ModelManager modelManager = mock(ModelManager.class);
    ThresholdResultTransportAction transportAction =
        new ThresholdResultTransportAction(mock(ActionFilters.class), noopTransport, modelManager);

    // The listener is the 4th argument; answer it immediately with grade 0, confidence 1.0, score 0.2.
    doAnswer(invocation -> {
        ActionListener<ThresholdingResult> resultListener = invocation.getArgument(3);
        ThresholdingResult stubbedResult = new ThresholdingResult(0, 1.0d, 0.2);
        resultListener.onResponse(stubbedResult);
        return null;
    }).when(modelManager).getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class));

    final PlainActionFuture<ThresholdResultResponse> responseFuture = new PlainActionFuture<>();
    transportAction.doExecute(mock(Task.class), new ThresholdResultRequest("123", "123-threshold", 2), responseFuture);

    ThresholdResultResponse thresholdResponse = responseFuture.actionGet();
    assertEquals(0, thresholdResponse.getAnomalyGrade(), 0.001);
    assertEquals(1, thresholdResponse.getConfidence(), 0.001);
}
Also used : Task(org.opensearch.tasks.Task) ActionFilters(org.opensearch.action.support.ActionFilters) ModelManager(org.opensearch.ad.ml.ModelManager) ThresholdingResult(org.opensearch.ad.ml.ThresholdingResult) ActionListener(org.opensearch.action.ActionListener) TransportService(org.opensearch.transport.TransportService) PlainActionFuture(org.opensearch.action.support.PlainActionFuture) Transport(org.opensearch.transport.Transport)

Example 8 with ThresholdingResult

Use of org.opensearch.ad.ml.ThresholdingResult in project anomaly-detection by opensearch-project.

The class AnomalyResultTests, method setUp.

/**
 * Builds the fixtures shared by the anomaly-result tests.
 *
 * <p>The fixtures are:
 * <ul>
 *   <li>test transport nodes and their cluster settings;</li>
 *   <li>a mocked detector with one enabled feature;</li>
 *   <li>a hash ring whose lookups resolve to the local node;</li>
 *   <li>mocked feature and model managers that answer their listeners immediately;</li>
 *   <li>a mocked client whose index calls succeed and whose get calls on the detection
 *       state index return a fresh internal state;</li>
 *   <li>execution counters;</li>
 *   <li>a task manager whose realtime-cache initialization reports success.</li>
 * </ul>
 *
 * @throws Exception if the parent class set-up fails
 */
@SuppressWarnings("unchecked")
@Override
@Before
public void setUp() throws Exception {
    super.setUp();
    super.setUpLog4jForJUnit(AnomalyResultTransportAction.class);
    setupTestNodes(AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, AnomalyDetectorSettings.PAGE_SIZE);
    transportService = testNodes[0].transportService;
    clusterService = testNodes[0].clusterService;
    settings = clusterService.getSettings();

    // Node state: detectors are never muted, and marking a cold start returns a no-op lambda.
    stateManager = mock(NodeStateManager.class);
    when(stateManager.isMuted(any(String.class), any(String.class))).thenReturn(false);
    when(stateManager.markColdStartRunning(anyString())).thenReturn(() -> {
    });

    // Detector with exactly one enabled feature, reading from indices matching "test*".
    detector = mock(AnomalyDetector.class);
    featureId = "xyz";
    // we have one feature
    when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList(featureId));
    featureName = "abc";
    when(detector.getEnabledFeatureNames()).thenReturn(Collections.singletonList(featureName));
    List<String> userIndex = new ArrayList<>();
    userIndex.add("test*");
    when(detector.getIndices()).thenReturn(userIndex);
    adID = "123";
    when(detector.getDetectorId()).thenReturn(adID);
    // No category field is configured for this detector.
    when(detector.getCategoryField()).thenReturn(null);
    doAnswer(invocation -> {
        ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
        listener.onResponse(Optional.of(detector));
        return null;
    }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class));
    when(detector.getDetectorIntervalInMinutes()).thenReturn(1L);

    // Hash ring: both lookups resolve to the local node.
    hashRing = mock(HashRing.class);
    Optional<DiscoveryNode> localNode = Optional.of(clusterService.state().nodes().getLocalNode());
    when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(localNode);
    doReturn(localNode).when(hashRing).getNodeByAddress(any());

    // Feature manager: the current features are a single zero-valued point.
    featureQuery = mock(FeatureManager.class);
    doAnswer(invocation -> {
        ActionListener<SinglePointFeatures> listener = invocation.getArgument(3);
        listener.onResponse(new SinglePointFeatures(Optional.of(new double[] { 0.0d }), Optional.of(new double[] { 0 })));
        return null;
    }).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class));

    // Model manager: canned RCF and threshold results delivered to the listener (4th argument).
    double rcfScore = 0.2;
    confidence = 0.91;
    anomalyGrade = 0.5;
    normalModelManager = mock(ModelManager.class);
    long totalUpdates = 1440;
    int relativeIndex = 0;
    double[] currentTimeAttribution = new double[] { 0.5, 0.5 };
    double[] pastValues = new double[] { 123, 456 };
    double[][] expectedValuesList = new double[][] { new double[] { 789, 12 } };
    double[] likelihood = new double[] { 1 };
    double threshold = 1.1d;
    doAnswer(invocation -> {
        ActionListener<ThresholdingResult> listener = invocation.getArgument(3);
        listener.onResponse(new ThresholdingResult(anomalyGrade, confidence, rcfScore, totalUpdates, relativeIndex, currentTimeAttribution, pastValues, expectedValuesList, likelihood, threshold, 30));
        return null;
    }).when(normalModelManager).getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class));
    doAnswer(invocation -> {
        ActionListener<ThresholdingResult> listener = invocation.getArgument(3);
        listener.onResponse(new ThresholdingResult(0, 1.0d, rcfScore));
        return null;
    }).when(normalModelManager).getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class));
    // Threshold model id derived from the detector id ("123-threshold" for adID "123").
    thresholdModelID = SingleStreamModelIdMapper.getThresholdModelId(adID);

    adCircuitBreakerService = mock(ADCircuitBreakerService.class);
    when(adCircuitBreakerService.isOpen()).thenReturn(false);

    // Client: exposes a real ThreadContext through a mocked thread pool; index calls succeed immediately.
    ThreadPool threadPool = mock(ThreadPool.class);
    client = mock(Client.class);
    ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
    when(client.threadPool()).thenReturn(threadPool);
    when(client.threadPool().getThreadContext()).thenReturn(threadContext);
    doAnswer(invocation -> {
        Object[] args = invocation.getArguments();
        assertTrue(String.format("The size of args is %d.  Its content is %s", args.length, Arrays.toString(args)), args.length >= 2);
        IndexRequest request = null;
        ActionListener<IndexResponse> listener = null;
        if (args[0] instanceof IndexRequest) {
            request = (IndexRequest) args[0];
        }
        if (args[1] instanceof ActionListener) {
            listener = (ActionListener<IndexResponse>) args[1];
        }
        assertTrue(request != null && listener != null);
        ShardId shardId = new ShardId(new Index(CommonName.ANOMALY_RESULT_INDEX_ALIAS, randomAlphaOfLength(10)), 0);
        listener.onResponse(new IndexResponse(shardId, randomAlphaOfLength(10), request.id(), 1, 1, 1, true));
        return null;
    }).when(client).index(any(), any());
    indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY));

    // Use a plain map rather than double-brace initialization. Double-brace initialization
    // creates an anonymous HashMap subclass that holds a hidden reference to this test instance.
    Map<String, ADStat<?>> statsMap = new HashMap<>();
    statsMap.put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
    statsMap.put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
    statsMap.put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
    statsMap.put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
    adStats = new ADStats(statsMap);

    // Only get requests on the detection state index are answered. Any other get request
    // leaves its listener uncompleted.
    doAnswer(invocation -> {
        Object[] args = invocation.getArguments();
        GetRequest request = (GetRequest) args[0];
        ActionListener<GetResponse> listener = (ActionListener<GetResponse>) args[1];
        if (request.index().equals(CommonName.DETECTION_STATE_INDEX)) {
            DetectorInternalState.Builder result = new DetectorInternalState.Builder().lastUpdateTime(Instant.now());
            listener.onResponse(TestHelpers.createGetResponse(result.build(), detector.getDetectorId(), CommonName.DETECTION_STATE_INDEX));
        }
        return null;
    }).when(client).get(any(), any());

    // Task manager: realtime task cache initialization always reports success.
    adTaskManager = mock(ADTaskManager.class);
    doAnswer(invocation -> {
        ActionListener<Boolean> listener = invocation.getArgument(3);
        listener.onResponse(true);
        return null;
    }).when(adTaskManager).initRealtimeTaskCacheAndCleanupStaleCache(anyString(), any(AnomalyDetector.class), any(TransportService.class), any(ActionListener.class));
}
Also used : HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Index(org.opensearch.index.Index) Matchers.containsString(org.hamcrest.Matchers.containsString) Mockito.anyString(org.mockito.Mockito.anyString) ArgumentMatchers.anyString(org.mockito.ArgumentMatchers.anyString) AnomalyDetector(org.opensearch.ad.model.AnomalyDetector) GetRequest(org.opensearch.action.get.GetRequest) Client(org.opensearch.client.Client) OpenSearchTestCase.randomBoolean(org.opensearch.test.OpenSearchTestCase.randomBoolean) Optional(java.util.Optional) ModelManager(org.opensearch.ad.ml.ModelManager) ThresholdingResult(org.opensearch.ad.ml.ThresholdingResult) ActionListener(org.opensearch.action.ActionListener) IndexResponse(org.opensearch.action.index.IndexResponse) IndexNameExpressionResolver(org.opensearch.cluster.metadata.IndexNameExpressionResolver) DiscoveryNode(org.opensearch.cluster.node.DiscoveryNode) ADStat(org.opensearch.ad.stats.ADStat) ThreadPool(org.opensearch.threadpool.ThreadPool) DetectorInternalState(org.opensearch.ad.model.DetectorInternalState) IndexRequest(org.opensearch.action.index.IndexRequest) NodeStateManager(org.opensearch.ad.NodeStateManager) HashRing(org.opensearch.ad.cluster.HashRing) ShardId(org.opensearch.index.shard.ShardId) CounterSupplier(org.opensearch.ad.stats.suppliers.CounterSupplier) SinglePointFeatures(org.opensearch.ad.feature.SinglePointFeatures) FeatureManager(org.opensearch.ad.feature.FeatureManager) ADCircuitBreakerService(org.opensearch.ad.breaker.ADCircuitBreakerService) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) GetResponse(org.opensearch.action.get.GetResponse) ADTaskManager(org.opensearch.ad.task.ADTaskManager) TransportService(org.opensearch.transport.TransportService) ADStats(org.opensearch.ad.stats.ADStats) Before(org.junit.Before)

Example 9 with ThresholdingResult

Use of org.opensearch.ad.ml.ThresholdingResult in project anomaly-detection by opensearch-project.

The class AnomalyResultTests, method testEndRunDueToNoTrainingData.

/**
 * Checks that an anomaly-result run ends with {@link EndRunException} in the following
 * situation: the RCF model lookup fails with a missing checkpoint index, and the state
 * manager reports a stored "Cannot get training data" end-run exception. Also checks that
 * a cold start was marked as running exactly once for the detector.
 */
@SuppressWarnings("unchecked")
public void testEndRunDueToNoTrainingData() {
    ThreadPool coldStartThreadPool = mock(ThreadPool.class);
    setUpColdStart(coldStartThreadPool, new ColdStartConfig.Builder().coldStartRunning(false).build());

    // RCF model lookup always fails as if the checkpoint index does not exist.
    ModelManager failingRcfManager = mock(ModelManager.class);
    doAnswer(invocation -> {
        ActionListener<ThresholdingResult> rcfListener = invocation.getArgument(3);
        rcfListener.onFailure(new IndexNotFoundException(CommonName.CHECKPOINT_INDEX_NAME));
        return null;
    }).when(failingRcfManager).getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class));

    EndRunException storedException = new EndRunException(adID, "Cannot get training data", false);
    when(stateManager.fetchExceptionAndClear(any(String.class))).thenReturn(Optional.of(storedException));

    // Cold start data is a single one-dimensional sample; training completes with a null response.
    doAnswer(invocation -> {
        ActionListener<Optional<double[][]>> coldStartListener = invocation.getArgument(1);
        coldStartListener.onResponse(Optional.of(new double[][] { { 1.0 } }));
        return null;
    }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class));
    doAnswer(invocation -> {
        ActionListener<Optional<Void>> trainListener = invocation.getArgument(2);
        trainListener.onResponse(null);
        return null;
    }).when(normalModelManager).trainModel(any(AnomalyDetector.class), any(double[][].class), any(ActionListener.class));

    // These constructors register handler in transport service
    new RCFResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, failingRcfManager, adCircuitBreakerService, hashRing);
    new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager);
    AnomalyResultTransportAction resultAction = new AnomalyResultTransportAction(
        new ActionFilters(Collections.emptySet()),
        transportService,
        settings,
        client,
        stateManager,
        featureQuery,
        normalModelManager,
        hashRing,
        clusterService,
        indexNameResolver,
        adCircuitBreakerService,
        adStats,
        coldStartThreadPool,
        NamedXContentRegistry.EMPTY,
        adTaskManager
    );

    AnomalyResultRequest resultRequest = new AnomalyResultRequest(adID, 100, 200);
    PlainActionFuture<AnomalyResultResponse> resultFuture = new PlainActionFuture<>();
    resultAction.doExecute(null, resultRequest, resultFuture);

    assertException(resultFuture, EndRunException.class);
    verify(stateManager, times(1)).markColdStartRunning(eq(adID));
}
Also used : EndRunException(org.opensearch.ad.common.exception.EndRunException) Optional(java.util.Optional) ThreadPool(org.opensearch.threadpool.ThreadPool) Matchers.containsString(org.hamcrest.Matchers.containsString) Mockito.anyString(org.mockito.Mockito.anyString) ArgumentMatchers.anyString(org.mockito.ArgumentMatchers.anyString) ActionFilters(org.opensearch.action.support.ActionFilters) ModelManager(org.opensearch.ad.ml.ModelManager) AnomalyDetector(org.opensearch.ad.model.AnomalyDetector) ThresholdingResult(org.opensearch.ad.ml.ThresholdingResult) ActionListener(org.opensearch.action.ActionListener) PlainActionFuture(org.opensearch.action.support.PlainActionFuture) IndexNotFoundException(org.opensearch.index.IndexNotFoundException)

Example 10 with ThresholdingResult

Use of org.opensearch.ad.ml.ThresholdingResult in project anomaly-detection by opensearch-project.

The class RCFResultTests, method testNormal.

/**
 * Checks that {@link RCFResultTransportAction} copies the RCF score, forest size and
 * attribution from the {@link ThresholdingResult} produced by
 * {@link ModelManager#getTRcfResult} into its {@link RCFResultResponse}.
 */
@SuppressWarnings("unchecked")
public void testNormal() {
    TransportService noopTransport = new TransportService(
        Settings.EMPTY,
        mock(Transport.class),
        null,
        TransportService.NOOP_TRANSPORT_INTERCEPTOR,
        boundAddress -> null,
        null,
        Collections.emptySet()
    );
    ModelManager modelManager = mock(ModelManager.class);
    ADCircuitBreakerService breakerService = mock(ADCircuitBreakerService.class);
    RCFResultTransportAction transportAction =
        new RCFResultTransportAction(mock(ActionFilters.class), noopTransport, modelManager, breakerService, hashRing);

    final double expectedScore = 0.5;
    final int expectedForestSize = 25;
    // The listener is the 4th argument; answer it immediately with the canned result.
    doAnswer(invocation -> {
        ActionListener<ThresholdingResult> resultListener = invocation.getArgument(3);
        resultListener.onResponse(
            new ThresholdingResult(grade, 0d, expectedScore, totalUpdates, 0, attribution, pastValues, expectedValuesList, likelihood, threshold, expectedForestSize)
        );
        return null;
    }).when(modelManager).getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class));
    when(breakerService.isOpen()).thenReturn(false);

    final PlainActionFuture<RCFResultResponse> responseFuture = new PlainActionFuture<>();
    RCFResultRequest rcfRequest = new RCFResultRequest("123", "123-rcf-1", new double[] { 0 });
    transportAction.doExecute(mock(Task.class), rcfRequest, responseFuture);

    RCFResultResponse rcfResponse = responseFuture.actionGet();
    assertEquals(expectedScore, rcfResponse.getRCFScore(), 0.001);
    assertEquals(expectedForestSize, rcfResponse.getForestSize(), 0.001);
    assertTrue(Arrays.equals(attribution, rcfResponse.getAttribution()));
}
Also used : Task(org.opensearch.tasks.Task) ADCircuitBreakerService(org.opensearch.ad.breaker.ADCircuitBreakerService) ActionFilters(org.opensearch.action.support.ActionFilters) ModelManager(org.opensearch.ad.ml.ModelManager) ThresholdingResult(org.opensearch.ad.ml.ThresholdingResult) ActionListener(org.opensearch.action.ActionListener) TransportService(org.opensearch.transport.TransportService) PlainActionFuture(org.opensearch.action.support.PlainActionFuture) Transport(org.opensearch.transport.Transport)

Aggregations

ThresholdingResult (org.opensearch.ad.ml.ThresholdingResult): 12, ModelManager (org.opensearch.ad.ml.ModelManager): 7, ActionListener (org.opensearch.action.ActionListener): 6, ArrayList (java.util.ArrayList): 5, ActionFilters (org.opensearch.action.support.ActionFilters): 5, AnomalyDetector (org.opensearch.ad.model.AnomalyDetector): 5, PlainActionFuture (org.opensearch.action.support.PlainActionFuture): 4, AnomalyResult (org.opensearch.ad.model.AnomalyResult): 4, TransportService (org.opensearch.transport.TransportService): 4, Instant (java.time.Instant): 3, Map (java.util.Map): 3, ADCircuitBreakerService (org.opensearch.ad.breaker.ADCircuitBreakerService): 3, Task (org.opensearch.tasks.Task): 3, Transport (org.opensearch.transport.Transport): 3, HashMap (java.util.HashMap): 2, List (java.util.List): 2, Optional (java.util.Optional): 2, ParameterizedMessage (org.apache.logging.log4j.message.ParameterizedMessage): 2, Matchers.containsString (org.hamcrest.Matchers.containsString): 2, ArgumentMatchers.anyString (org.mockito.ArgumentMatchers.anyString): 2