Use of org.opensearch.ad.transport.ADTaskProfileNodeResponse in the anomaly-detection project by opensearch-project.
In class ADTaskManager, method getADTaskProfile:
/**
 * Get AD task profile.
 *
 * <p>Sends an {@code ADTaskProfileAction} request for this task's detector to the data nodes
 * supplied by {@code hashRing.getAllEligibleDataNodesWithKnownAdVersion}, then merges the
 * per-node responses into a single {@link ADTaskProfile} seeded from {@code adDetectorLevelTask}:
 * <ul>
 *   <li>If the transport response reports any node failure, the first failure is passed to
 *       {@code listener.onFailure} and no profile is produced.</li>
 *   <li>Every node profile with a non-null node id overwrites the detector-level fields
 *       (task id, shingle size, RCF/threshold model state, model size, node id, entity counts,
 *       task slots, running entities, task type). If several nodes qualify, the last one
 *       iterated wins.</li>
 *   <li>Entity task profiles from all nodes are concatenated, and they are attached to the
 *       result only when at least one was reported.</li>
 * </ul>
 *
 * @param adDetectorLevelTask detector level task
 * @param listener action listener; receives the merged profile, or the first node failure /
 *        transport error (transport errors are also logged)
 */
private void getADTaskProfile(ADTask adDetectorLevelTask, ActionListener<ADTaskProfile> listener) {
    String detectorId = adDetectorLevelTask.getDetectorId();
    // NOTE(review): listener is also handed to hashRing as its second argument, presumably so that
    // node-lookup failures reach the caller — confirm against HashRing.
    hashRing.getAllEligibleDataNodesWithKnownAdVersion(dataNodes -> {
        ADTaskProfileRequest adTaskProfileRequest = new ADTaskProfileRequest(detectorId, dataNodes);
        client.execute(ADTaskProfileAction.INSTANCE, adTaskProfileRequest, ActionListener.wrap(response -> {
            if (response.hasFailures()) {
                // Any single node failure fails the whole profile request.
                listener.onFailure(response.failures().get(0));
                return;
            }
            List<ADEntityTaskProfile> adEntityTaskProfiles = new ArrayList<>();
            ADTaskProfile detectorTaskProfile = new ADTaskProfile(adDetectorLevelTask);
            for (ADTaskProfileNodeResponse node : response.getNodes()) {
                ADTaskProfile taskProfile = node.getAdTaskProfile();
                if (taskProfile == null) {
                    continue;
                }
                if (taskProfile.getNodeId() != null) {
                    // HC detector: task profile from coordinating node
                    // Single entity detector: task profile from worker node
                    detectorTaskProfile.setTaskId(taskProfile.getTaskId());
                    detectorTaskProfile.setShingleSize(taskProfile.getShingleSize());
                    detectorTaskProfile.setRcfTotalUpdates(taskProfile.getRcfTotalUpdates());
                    detectorTaskProfile.setThresholdModelTrained(taskProfile.getThresholdModelTrained());
                    detectorTaskProfile.setThresholdModelTrainingDataSize(taskProfile.getThresholdModelTrainingDataSize());
                    detectorTaskProfile.setModelSizeInBytes(taskProfile.getModelSizeInBytes());
                    detectorTaskProfile.setNodeId(taskProfile.getNodeId());
                    detectorTaskProfile.setTotalEntitiesCount(taskProfile.getTotalEntitiesCount());
                    detectorTaskProfile.setDetectorTaskSlots(taskProfile.getDetectorTaskSlots());
                    detectorTaskProfile.setPendingEntitiesCount(taskProfile.getPendingEntitiesCount());
                    detectorTaskProfile.setRunningEntitiesCount(taskProfile.getRunningEntitiesCount());
                    detectorTaskProfile.setRunningEntities(taskProfile.getRunningEntities());
                    detectorTaskProfile.setAdTaskType(taskProfile.getAdTaskType());
                }
                if (taskProfile.getEntityTaskProfiles() != null) {
                    adEntityTaskProfiles.addAll(taskProfile.getEntityTaskProfiles());
                }
            }
            // The list is a fresh local and never null. Skip the setter entirely when no node
            // reported entity task profiles.
            if (!adEntityTaskProfiles.isEmpty()) {
                detectorTaskProfile.setEntityTaskProfiles(adEntityTaskProfiles);
            }
            listener.onResponse(detectorTaskProfile);
        }, e -> {
            logger.error("Failed to get task profile for task " + adDetectorLevelTask.getTaskId(), e);
            listener.onFailure(e);
        }));
    }, listener);
}
Use of org.opensearch.ad.transport.ADTaskProfileNodeResponse in the anomaly-detection project by opensearch-project.
In class ADTaskManagerTests, method setupGetAndExecuteOnLatestADTasks:
/**
 * Stubs the mocked collaborators ({@code client}, {@code hashRing}, {@code nodeFilter}).
 *
 * <p>After setup, task searches see one running historical HC task and one running realtime HC
 * task, and the task-profile transport call returns {@code adTaskProfile}. Stubs installed:
 * <ul>
 *   <li>{@code client.search}: two hits. The first is the running historical HC task. The second
 *       is the same document with the task type rewritten to REALTIME_HC_DETECTOR and the task id
 *       rewritten to {@code realtimeTaskId}.</li>
 *   <li>{@code hashRing.getAllEligibleDataNodesWithKnownAdVersion}: invokes its consumer with
 *       {@code node1} and {@code node2}.</li>
 *   <li>{@code client.execute}: the first call answers with an {@code ADTaskProfileResponse}
 *       holding a single {@code node1} response that carries {@code adTaskProfile}. Every later
 *       call answers with a mocked {@code BulkByScrollResponse} whose bulk failures are
 *       {@code null}.</li>
 *   <li>{@code nodeFilter.getEligibleDataNodes}: returns {@code node1} and {@code node2}.</li>
 *   <li>{@code client.update}: answers with an {@code UPDATED} {@code UpdateResponse}.</li>
 *   <li>{@code client.get}: answers with an existing detector job document for a random
 *       detector id.</li>
 * </ul>
 *
 * @param adTaskProfile profile returned by the stubbed task-profile call
 */
private void setupGetAndExecuteOnLatestADTasks(ADTaskProfile adTaskProfile) {
    // The realtime hit is the historical HC task document with its task type and task id swapped.
    String runningRealtimeHCTaskContent = runningHistoricalHCTaskContent
        .replace(ADTaskType.HISTORICAL_HC_DETECTOR.name(), ADTaskType.REALTIME_HC_DETECTOR.name())
        .replace(historicalTaskId, realtimeTaskId);
    doAnswer(invocation -> {
        ActionListener<SearchResponse> searchListener = invocation.getArgument(1);
        SearchHit historicalTask = SearchHit.fromXContent(TestHelpers.parser(runningHistoricalHCTaskContent));
        SearchHit realtimeTask = SearchHit.fromXContent(TestHelpers.parser(runningRealtimeHCTaskContent));
        SearchHits searchHits = new SearchHits(
            new SearchHit[] { historicalTask, realtimeTask },
            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
            Float.NaN
        );
        InternalSearchResponse internalResponse = new InternalSearchResponse(
            searchHits,
            InternalAggregations.EMPTY,
            null,
            null,
            false,
            null,
            1
        );
        SearchResponse searchResponse = new SearchResponse(
            internalResponse,
            null,
            1,
            1,
            0,
            100,
            ShardSearchFailure.EMPTY_ARRAY,
            SearchResponse.Clusters.EMPTY
        );
        searchListener.onResponse(searchResponse);
        return null;
    }).when(client).search(any(), any());

    String detectorId = randomAlphaOfLength(5);

    doAnswer(invocation -> {
        Consumer<DiscoveryNode[]> getNodeFunction = invocation.getArgument(0);
        getNodeFunction.accept(new DiscoveryNode[] { node1, node2 });
        return null;
    }).when(hashRing).getAllEligibleDataNodesWithKnownAdVersion(any(), any());

    // Consecutive stubbing: the first execute() gets the task-profile answer. Every later
    // execute() gets the BulkByScrollResponse answer.
    doAnswer(invocation -> {
        ActionListener<ADTaskProfileResponse> taskProfileResponseListener = invocation.getArgument(2);
        ADTaskProfileNodeResponse nodeResponse = new ADTaskProfileNodeResponse(node1, adTaskProfile, Version.CURRENT);
        ImmutableList<ADTaskProfileNodeResponse> nodes = ImmutableList.of(nodeResponse);
        ADTaskProfileResponse taskProfileResponse = new ADTaskProfileResponse(new ClusterName("test"), nodes, ImmutableList.of());
        taskProfileResponseListener.onResponse(taskProfileResponse);
        return null;
    }).doAnswer(invocation -> {
        ActionListener<BulkByScrollResponse> bulkResponseListener = invocation.getArgument(2);
        BulkByScrollResponse bulkResponse = mock(BulkByScrollResponse.class);
        when(bulkResponse.getBulkFailures()).thenReturn(null);
        bulkResponseListener.onResponse(bulkResponse);
        return null;
    }).when(client).execute(any(), any(), any());

    when(nodeFilter.getEligibleDataNodes()).thenReturn(new DiscoveryNode[] { node1, node2 });

    doAnswer(invocation -> {
        ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(1);
        UpdateResponse updateResponse = new UpdateResponse(
            ShardId.fromString("[test][1]"),
            CommonName.MAPPING_TYPE,
            "1",
            0L,
            1L,
            1L,
            DocWriteResponse.Result.UPDATED
        );
        updateResponseListener.onResponse(updateResponse);
        return null;
    }).when(client).update(any(), any());

    doAnswer(invocation -> {
        ActionListener<GetResponse> getResponseListener = invocation.getArgument(1);
        AnomalyDetectorJob job = new AnomalyDetectorJob(
            detectorId,
            randomIntervalSchedule(),
            randomIntervalTimeConfiguration(),
            false,
            Instant.now().minusSeconds(60),
            Instant.now(),
            Instant.now(),
            60L,
            TestHelpers.randomUser(),
            null
        );
        BytesReference jobSource = BytesReference.bytes(job.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS));
        // exists == true: the job document is reported as found.
        GetResult getResult = new GetResult(
            AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX,
            MapperService.SINGLE_MAPPING_NAME,
            detectorId,
            UNASSIGNED_SEQ_NO,
            0,
            -1,
            true,
            jobSource,
            Collections.emptyMap(),
            Collections.emptyMap()
        );
        getResponseListener.onResponse(new GetResponse(getResult));
        return null;
    }).when(client).get(any(), any());
}
Aggregations