Search in sources:

Example 16 with EntityModel

use of org.opensearch.ad.ml.EntityModel in project anomaly-detection by opensearch-project.

Class EntityColdStartWorker, method executeRequest.

/**
 * Kicks off cold-start training for the entity named in {@code coldStartRequest}.
 *
 * <p>If the request carries no model id, a warning is logged and {@code listener}
 * is failed with a {@link RuntimeException}. Otherwise a fresh model state
 * (empty sample queue, no trained model) is created and handed to the entity
 * cold starter together with a listener that reacts to training failures.
 *
 * @param coldStartRequest the entity request to cold start
 * @param listener         listener that is failed on error; otherwise passed on
 *                         (wrapped) to the cold starter
 */
@Override
protected void executeRequest(EntityRequest coldStartRequest, ActionListener<Void> listener) {
    String detectorId = coldStartRequest.getDetectorId();
    Optional<String> optionalModelId = coldStartRequest.getModelId();

    // Nothing can be trained without a model id: report and stop.
    if (!optionalModelId.isPresent()) {
        String message = String.format(Locale.ROOT, "Fail to get model id for request %s", coldStartRequest);
        LOG.warn(message);
        listener.onFailure(new RuntimeException(message));
        return;
    }

    // Start from a model with no samples and no trained forest.
    EntityModel emptyModel = new EntityModel(coldStartRequest.getEntity(), new ArrayDeque<>(), null);
    ModelState<EntityModel> freshState = new ModelState<>(
        emptyModel,
        optionalModelId.get(),
        detectorId,
        ModelType.ENTITY.getName(),
        clock,
        0
    );

    // On failure: begin a cool-down if the cluster is overloaded, remember the
    // exception for this detector, then pass the failure on to the caller.
    ActionListener<Void> trainingFailureHandler = ActionListener.delegateResponse(listener, (delegate, cause) -> {
        if (ExceptionUtil.isOverloaded(cause)) {
            LOG.error("OpenSearch is overloaded");
            setCoolDownStart();
        }
        nodeStateManager.setException(detectorId, cause);
        delegate.onFailure(cause);
    });

    entityColdStarter.trainModel(coldStartRequest.getEntity(), detectorId, freshState, trainingFailureHandler);
}
Also used : EntityModel(org.opensearch.ad.ml.EntityModel) ModelState(org.opensearch.ad.ml.ModelState) ArrayDeque(java.util.ArrayDeque)

Example 17 with EntityModel

use of org.opensearch.ad.ml.EntityModel in project anomaly-detection by opensearch-project.

Class EntityResultTransportActionTests, method setUp.

/**
 * Builds the fixture shared by the tests in this class: mocked collaborators,
 * a request carrying three entities (one stubbed as a cache miss, one stubbed
 * as a cache hit, and one whose name is longer than
 * {@code AnomalyDetectorSettings.MAX_ENTITY_LENGTH}), and the
 * {@code EntityResultTransportAction} under test.
 *
 * @throws Exception if the parent set-up fails
 */
@SuppressWarnings("unchecked")
@Override
@Before
public void setUp() throws Exception {
    super.setUp();
    actionFilters = mock(ActionFilters.class);
    transportService = mock(TransportService.class);
    // Circuit breaker stub: isOpen() always answers false.
    adCircuitBreakerService = mock(ADCircuitBreakerService.class);
    when(adCircuitBreakerService.isOpen()).thenReturn(false);
    checkpointDao = mock(CheckpointDao.class);
    detectorId = "123";
    // The map is filled further down; the request keeps a reference to it.
    entities = new HashMap<>();
    start = 10L;
    end = 20L;
    request = new EntityResultRequest(detectorId, entities, start, end);
    // Mocked clock whose instant() always returns the time captured here.
    clock = mock(Clock.class);
    now = Instant.now();
    when(clock.instant()).thenReturn(now);
    manager = new ModelManager(null, clock, 0, 0, 0, 0, 0, 0, null, null, mock(EntityColdStarter.class), null, null);
    // The provider hands out the mocked entity cache.
    provider = mock(CacheProvider.class);
    entityCache = mock(EntityCache.class);
    when(provider.get()).thenReturn(entityCache);
    String field = "a";
    detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field));
    // Any detector lookup immediately responds with the detector created above.
    stateManager = mock(NodeStateManager.class);
    doAnswer(invocation -> {
        ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
        listener.onResponse(Optional.of(detector));
        return null;
    }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class));
    // Entity for which the cache is stubbed to return null (miss).
    cacheMissEntity = "0.0.0.1";
    cacheMissData = new double[] { 0.1 };
    // Entity for which the cache is stubbed to return a model state (hit).
    cacheHitEntity = "0.0.0.2";
    cacheHitData = new double[] { 0.2 };
    cacheMissEntityObj = Entity.createSingleAttributeEntity(detector.getCategoryField().get(0), cacheMissEntity);
    entities.put(cacheMissEntityObj, cacheMissData);
    cacheHitEntityObj = Entity.createSingleAttributeEntity(detector.getCategoryField().get(0), cacheHitEntity);
    entities.put(cacheHitEntityObj, cacheHitData);
    // Entity whose name is one character longer than MAX_ENTITY_LENGTH.
    tooLongEntity = randomAlphaOfLength(AnomalyDetectorSettings.MAX_ENTITY_LENGTH + 1);
    tooLongData = new double[] { 0.3 };
    entities.put(Entity.createSingleAttributeEntity(detector.getCategoryField().get(0), tooLongEntity), tooLongData);
    ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build());
    when(entityCache.get(eq(cacheMissEntityObj.getModelId(detectorId).get()), any())).thenReturn(null);
    when(entityCache.get(eq(cacheHitEntityObj.getModelId(detectorId).get()), any())).thenReturn(state);
    // selectUpdateCandidate is stubbed to return an empty left list and a right
    // list that contains only the cache-miss entity.
    List<Entity> coldEntities = new ArrayList<>();
    coldEntities.add(cacheMissEntityObj);
    when(entityCache.selectUpdateCandidate(any(), anyString(), any())).thenReturn(Pair.of(new ArrayList<>(), coldEntities));
    settings = Settings.builder().put(AnomalyDetectorSettings.COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)).build();
    AnomalyDetectionIndices indexUtil = mock(AnomalyDetectionIndices.class);
    when(indexUtil.getSchemaVersion(any())).thenReturn(CommonValue.NO_SCHEMA_VERSION);
    resultWriteQueue = mock(ResultWriteWorker.class);
    checkpointReadQueue = mock(CheckpointReadWorker.class);
    minSamples = 1;
    // Cold starter stub: training from existing samples just clears the model.
    // NOTE(review): this mock is not the one passed to ModelManager above (that
    // constructor call creates its own mock) — confirm it is wired in elsewhere.
    coldStarter = mock(EntityColdStarter.class);
    doAnswer(invocation -> {
        ModelState<EntityModel> modelState = invocation.getArgument(0);
        modelState.getModel().clear();
        return null;
    }).when(coldStarter).trainModelFromExistingSamples(any(), anyInt());
    coldEntityQueue = mock(ColdEntityWorker.class);
    // Action under test, wired with the collaborators prepared above.
    entityResult = new EntityResultTransportAction(actionFilters, transportService, manager, adCircuitBreakerService, provider, stateManager, indexUtil, resultWriteQueue, checkpointReadQueue, coldEntityQueue, threadPool);
    // timeout in 60 seconds
    timeoutMs = 60000L;
}
Also used : Entity(org.opensearch.ad.model.Entity) EntityCache(org.opensearch.ad.caching.EntityCache) ArrayList(java.util.ArrayList) CheckpointReadWorker(org.opensearch.ad.ratelimit.CheckpointReadWorker) ArgumentMatchers.anyString(org.mockito.ArgumentMatchers.anyString) Clock(java.time.Clock) ResultWriteWorker(org.opensearch.ad.ratelimit.ResultWriteWorker) NodeStateManager(org.opensearch.ad.NodeStateManager) RandomModelStateConfig(test.org.opensearch.ad.util.RandomModelStateConfig) EntityColdStarter(org.opensearch.ad.ml.EntityColdStarter) ColdEntityWorker(org.opensearch.ad.ratelimit.ColdEntityWorker) ADCircuitBreakerService(org.opensearch.ad.breaker.ADCircuitBreakerService) Optional(java.util.Optional) EntityModel(org.opensearch.ad.ml.EntityModel) ActionFilters(org.opensearch.action.support.ActionFilters) ModelManager(org.opensearch.ad.ml.ModelManager) CacheProvider(org.opensearch.ad.caching.CacheProvider) CheckpointDao(org.opensearch.ad.ml.CheckpointDao) ActionListener(org.opensearch.action.ActionListener) TransportService(org.opensearch.transport.TransportService) AnomalyDetectionIndices(org.opensearch.ad.indices.AnomalyDetectionIndices) Before(org.junit.Before)

Example 18 with EntityModel

use of org.opensearch.ad.ml.EntityModel in project anomaly-detection by opensearch-project.

Class EntityResultTransportAction, method onGetDetector.

/**
 * Builds the listener that processes an entity result request once the
 * detector configuration has been fetched.
 *
 * <p>When the detector is present the listener:
 * <ol>
 *   <li>scores every entity whose model is found in the cache and queues a
 *       result write when the RCF score is positive;</li>
 *   <li>collects the entities whose model is not cached, asks the cache to
 *       split them into two groups, and queues the left group on the
 *       checkpoint-read queue (MEDIUM priority) and the right group on the
 *       cold-entity queue (LOW priority);</li>
 *   <li>answers {@code listener} — with {@code prevException} if one was
 *       supplied, otherwise with an acknowledged response.</li>
 * </ol>
 *
 * @param listener      listener to notify when processing finishes or fails
 * @param detectorId    id of the detector the request belongs to
 * @param request       entities and feature values for one interval
 * @param prevException failure from an earlier step that should still be
 *                      reported to {@code listener} after processing
 * @return listener to pass to the detector lookup
 */
private ActionListener<Optional<AnomalyDetector>> onGetDetector(ActionListener<AcknowledgedResponse> listener, String detectorId, EntityResultRequest request, Optional<Exception> prevException) {
    return ActionListener.wrap(detectorOptional -> {
        if (!detectorOptional.isPresent()) {
            listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", true));
            return;
        }
        AnomalyDetector detector = detectorOptional.get();
        if (request.getEntities() == null) {
            listener.onResponse(null);
            return;
        }
        Instant executionStartTime = Instant.now();
        Map<Entity, double[]> cacheMissEntities = new HashMap<>();
        for (Entry<Entity, double[]> entityEntry : request.getEntities().entrySet()) {
            Entity categoricalValues = entityEntry.getKey();
            if (isEntityeFromOldNodeMsg(categoricalValues) && detector.getCategoryField() != null && detector.getCategoryField().size() == 1) {
                Map<String, String> attrValues = categoricalValues.getAttributes();
                // handle a request from a version before OpenSearch 1.1.
                categoricalValues = Entity.createSingleAttributeEntity(detector.getCategoryField().get(0), attrValues.get(CommonName.EMPTY_FIELD));
            }
            Optional<String> modelIdOptional = categoricalValues.getModelId(detectorId);
            if (!modelIdOptional.isPresent()) {
                // no model id can be derived for this entity; skip it
                continue;
            }
            String modelId = modelIdOptional.get();
            double[] datapoint = entityEntry.getValue();
            ModelState<EntityModel> entityModel = cache.get().get(modelId, detector);
            if (entityModel == null) {
                // cache miss
                cacheMissEntities.put(categoricalValues, datapoint);
                continue;
            }
            ThresholdingResult result = modelManager.getAnomalyResultForEntity(datapoint, entityModel, modelId, categoricalValues, detector.getShingleSize());
            // Only save results with a positive RCF score: writing unconditionally
            // causes many OpenSearchRejectedExecutionException.
            if (result.getRcfScore() > 0) {
                AnomalyResult resultToSave = result.toAnomalyResult(detector, Instant.ofEpochMilli(request.getStart()), Instant.ofEpochMilli(request.getEnd()), executionStartTime, Instant.now(), ParseUtils.getFeatureData(datapoint, detector), categoricalValues, indexUtil.getSchemaVersion(ADIndex.RESULT), modelId, null, null);
                resultWriteQueue.put(new ResultWriteRequest(System.currentTimeMillis() + detector.getDetectorIntervalInMilliseconds(), detectorId, result.getGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM, resultToSave, detector.getResultIndex()));
            }
        }
        // split hot and cold entities
        Pair<List<Entity>, List<Entity>> hotColdEntities = cache.get().selectUpdateCandidate(cacheMissEntities.keySet(), detectorId, detector);
        // hot entities have MEDIUM priority, cold entities have LOW priority
        List<EntityFeatureRequest> hotEntityRequests = buildEntityFeatureRequests(hotColdEntities.getLeft(), cacheMissEntities, detector, detectorId, request, RequestPriority.MEDIUM);
        List<EntityFeatureRequest> coldEntityRequests = buildEntityFeatureRequests(hotColdEntities.getRight(), cacheMissEntities, detector, detectorId, request, RequestPriority.LOW);
        checkpointReadQueue.putAll(hotEntityRequests);
        coldEntityQueue.putAll(coldEntityRequests);
        // respond back
        if (prevException.isPresent()) {
            listener.onFailure(prevException.get());
        } else {
            listener.onResponse(new AcknowledgedResponse(true));
        }
    }, exception -> {
        LOG.error(new ParameterizedMessage("fail to get entity's anomaly grade for detector [{}]: start: [{}], end: [{}]", detectorId, request.getStart(), request.getEnd()), exception);
        listener.onFailure(exception);
    });
}

/**
 * Turns each entity into an {@code EntityFeatureRequest} that expires one
 * detector interval from now, skipping (and logging) entities that have no
 * feature value in {@code featureValues}.
 *
 * @param candidates    entities to create requests for
 * @param featureValues feature value per cache-miss entity
 * @param detector      detector configuration, used for the interval length
 * @param detectorId    id of the detector
 * @param request       originating request, used for the data start time
 * @param priority      priority assigned to every created request
 * @return the created requests, in the iteration order of {@code candidates}
 */
private List<EntityFeatureRequest> buildEntityFeatureRequests(List<Entity> candidates, Map<Entity, double[]> featureValues, AnomalyDetector detector, String detectorId, EntityResultRequest request, RequestPriority priority) {
    List<EntityFeatureRequest> featureRequests = new ArrayList<>();
    for (Entity candidate : candidates) {
        double[] featureValue = featureValues.get(candidate);
        if (featureValue == null) {
            LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", candidate));
            continue;
        }
        featureRequests.add(new EntityFeatureRequest(System.currentTimeMillis() + detector.getDetectorIntervalInMilliseconds(), detectorId, priority, candidate, featureValue, request.getStart()));
    }
    return featureRequests;
}
Also used : Entity(org.opensearch.ad.model.Entity) EndRunException(org.opensearch.ad.common.exception.EndRunException) EntityFeatureRequest(org.opensearch.ad.ratelimit.EntityFeatureRequest) HashMap(java.util.HashMap) Instant(java.time.Instant) EntityModel(org.opensearch.ad.ml.EntityModel) ArrayList(java.util.ArrayList) AcknowledgedResponse(org.opensearch.action.support.master.AcknowledgedResponse) AnomalyDetector(org.opensearch.ad.model.AnomalyDetector) ThresholdingResult(org.opensearch.ad.ml.ThresholdingResult) AnomalyResult(org.opensearch.ad.model.AnomalyResult) ArrayList(java.util.ArrayList) List(java.util.List) ParameterizedMessage(org.apache.logging.log4j.message.ParameterizedMessage) ResultWriteRequest(org.opensearch.ad.ratelimit.ResultWriteRequest)

Example 19 with EntityModel

use of org.opensearch.ad.ml.EntityModel in project anomaly-detection by opensearch-project.

Class PriorityCacheTests, method testCacheHit.

/**
 * Verifies the cache-hit path of {@code PriorityCache}: the first two lookups
 * miss, a hosted model state is then returned by a lookup, the returned state
 * shares its model with the hosted one, and memory for the dedicated cache is
 * consumed exactly once with the expected size, reservation flag and origin.
 */
public void testCacheHit() {
    // 800 MB is the limit
    long largeHeapSize = 800_000_000;
    JvmInfo info = mock(JvmInfo.class);
    Mem mem = mock(Mem.class);
    when(info.getMem()).thenReturn(mem);
    when(mem.getHeapMax()).thenReturn(new ByteSizeValue(largeHeapSize));
    JvmService jvmService = mock(JvmService.class);
    when(jvmService.info()).thenReturn(info);
    float modelMaxPercen = 0.1f;
    // Spy on the tracker so the consumeMemory call can be verified below.
    memoryTracker = spy(new MemoryTracker(jvmService, modelMaxPercen, 0.002, clusterService, mock(ADCircuitBreakerService.class)));
    EntityCache cache = new PriorityCache(checkpoint, dedicatedCacheSize, AnomalyDetectorSettings.CHECKPOINT_TTL, AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, memoryTracker, AnomalyDetectorSettings.NUM_TREES, clock, clusterService, AnomalyDetectorSettings.HOURLY_MAINTENANCE, threadPool, checkpointWriteQueue, AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT);
    cacheProvider = new CacheProvider(cache).get();
    // cache miss due to door keeper
    assertEquals(null, cacheProvider.get(modelState1.getModelId(), detector));
    // cache miss due to empty cache
    assertEquals(null, cacheProvider.get(modelState1.getModelId(), detector));
    cacheProvider.hostIfPossible(detector, modelState1);
    assertEquals(1, cacheProvider.getTotalActiveEntities());
    assertEquals(1, cacheProvider.getAllModels().size());
    ModelState<EntityModel> hitState = cacheProvider.get(modelState1.getModelId(), detector);
    assertEquals(detectorId, hitState.getDetectorId());
    EntityModel model = hitState.getModel();
    assertEquals(false, model.getTrcf().isPresent());
    assertTrue(model.getSamples().isEmpty());
    // A sample added through the hosted state must be visible through the
    // model returned by the cache.
    modelState1.getModel().addSample(point);
    assertTrue(Arrays.equals(point, model.getSamples().peek()));
    ArgumentCaptor<Long> memoryConsumed = ArgumentCaptor.forClass(Long.class);
    ArgumentCaptor<Boolean> reserved = ArgumentCaptor.forClass(Boolean.class);
    ArgumentCaptor<MemoryTracker.Origin> origin = ArgumentCaptor.forClass(MemoryTracker.Origin.class);
    // input dimension: 3, shingle: 4
    long expectedMemoryPerEntity = 436828L;
    verify(memoryTracker, times(1)).consumeMemory(memoryConsumed.capture(), reserved.capture(), origin.capture());
    // Compare as long: narrowing the captured Long to int would silently
    // truncate values above Integer.MAX_VALUE and could mask a wrong size.
    assertEquals(dedicatedCacheSize * expectedMemoryPerEntity, memoryConsumed.getValue().longValue());
    assertEquals(true, reserved.getValue().booleanValue());
    assertEquals(MemoryTracker.Origin.HC_DETECTOR, origin.getValue());
}
Also used : ADCircuitBreakerService(org.opensearch.ad.breaker.ADCircuitBreakerService) JvmInfo(org.opensearch.monitor.jvm.JvmInfo) ByteSizeValue(org.opensearch.common.unit.ByteSizeValue) EntityModel(org.opensearch.ad.ml.EntityModel) MemoryTracker(org.opensearch.ad.MemoryTracker) Mem(org.opensearch.monitor.jvm.JvmInfo.Mem) JvmService(org.opensearch.monitor.jvm.JvmService) Mockito.anyLong(org.mockito.Mockito.anyLong) Mockito.anyBoolean(org.mockito.Mockito.anyBoolean)

Example 20 with EntityModel

use of org.opensearch.ad.ml.EntityModel in project anomaly-detection by opensearch-project.

Class MLUtil, method randomModelState.

/**
 * Creates a {@code ModelState<EntityModel>} for tests from {@code config}.
 *
 * <p>Any setting the config leaves {@code null} falls back to a default:
 * a random priority, a random 15-character detector id, a random sample size
 * below {@code minSampleSize}, and the system UTC clock. A missing full-model
 * flag means an empty model is created. The detector id is also used as the
 * model id of the returned state.
 *
 * @param config requested properties of the model state
 * @return a model state holding either a non-empty or an empty entity model
 */
public static ModelState<EntityModel> randomModelState(RandomModelStateConfig config) {
    // A null flag counts as "not a full model".
    boolean buildFullModel = Boolean.TRUE.equals(config.getFullModel());

    float priority;
    if (config.getPriority() == null) {
        priority = random.nextFloat();
    } else {
        priority = config.getPriority();
    }

    String detectorId = config.getDetectorId();
    if (detectorId == null) {
        detectorId = randomString(15);
    }

    int sampleSize;
    if (config.getSampleSize() == null) {
        sampleSize = random.nextInt(minSampleSize);
    } else {
        sampleSize = config.getSampleSize();
    }

    Clock clock = config.getClock();
    if (clock == null) {
        clock = Clock.systemUTC();
    }

    // Two-attribute entity when requested, otherwise a single empty attribute.
    Entity entity;
    if (!config.hasEntityAttributes()) {
        entity = Entity.createSingleAttributeEntity("", "");
    } else {
        Map<String, Object> attributes = new HashMap<>();
        attributes.put("a", "a1");
        attributes.put("b", "b1");
        entity = Entity.createEntityByReordering(attributes);
    }

    EntityModel model = buildFullModel
        ? createNonEmptyModel(detectorId, sampleSize, entity)
        : createEmptyModel(entity, sampleSize);

    return new ModelState<>(model, detectorId, detectorId, ModelType.ENTITY.getName(), clock, priority);
}
Also used : Entity(org.opensearch.ad.model.Entity) HashMap(java.util.HashMap) EntityModel(org.opensearch.ad.ml.EntityModel) ModelState(org.opensearch.ad.ml.ModelState) Clock(java.time.Clock)

Aggregations

EntityModel (org.opensearch.ad.ml.EntityModel)24 ModelState (org.opensearch.ad.ml.ModelState)8 Entity (org.opensearch.ad.model.Entity)8 ArrayList (java.util.ArrayList)7 Instant (java.time.Instant)5 ParameterizedMessage (org.apache.logging.log4j.message.ParameterizedMessage)4 EntityCache (org.opensearch.ad.caching.EntityCache)4 Clock (java.time.Clock)3 HashMap (java.util.HashMap)3 Before (org.junit.Before)3 ThresholdingResult (org.opensearch.ad.ml.ThresholdingResult)3 AnomalyDetector (org.opensearch.ad.model.AnomalyDetector)3 RandomModelStateConfig (test.org.opensearch.ad.util.RandomModelStateConfig)3 ArrayDeque (java.util.ArrayDeque)2 Map (java.util.Map)2 Optional (java.util.Optional)2 Random (java.util.Random)2 ArgumentMatchers.anyString (org.mockito.ArgumentMatchers.anyString)2 CacheProvider (org.opensearch.ad.caching.CacheProvider)2 AnomalyDetectionIndices (org.opensearch.ad.indices.AnomalyDetectionIndices)2