Use of org.opensearch.ad.transport.ADTaskProfileRequest in the anomaly-detection project by opensearch-project.
From the class ADTaskManager, method getADTaskProfile.
/**
 * Builds the profile of a detector level task by querying data nodes.
 *
 * <p>The eligible data nodes are obtained from the hash ring, an
 * {@link ADTaskProfileRequest} is executed against them, and the per-node
 * profiles are merged into a single {@link ADTaskProfile} seeded from
 * {@code adDetectorLevelTask}. If the response reports any node failure, the
 * first failure is forwarded to the listener and no profile is returned.
 *
 * @param adDetectorLevelTask detector level task whose profile is requested
 * @param listener action listener notified with the merged profile or with the failure
 */
private void getADTaskProfile(ADTask adDetectorLevelTask, ActionListener<ADTaskProfile> listener) {
    String detectorId = adDetectorLevelTask.getDetectorId();
    hashRing.getAllEligibleDataNodesWithKnownAdVersion(eligibleNodes -> {
        ADTaskProfileRequest profileRequest = new ADTaskProfileRequest(detectorId, eligibleNodes);
        client.execute(ADTaskProfileAction.INSTANCE, profileRequest, ActionListener.wrap(profileResponse -> {
            // Any node-level failure aborts the whole profile call.
            if (profileResponse.hasFailures()) {
                listener.onFailure(profileResponse.failures().get(0));
                return;
            }

            List<ADEntityTaskProfile> mergedEntityProfiles = new ArrayList<>();
            ADTaskProfile mergedProfile = new ADTaskProfile(adDetectorLevelTask);

            for (ADTaskProfileNodeResponse nodeResponse : profileResponse.getNodes()) {
                ADTaskProfile nodeProfile = nodeResponse.getAdTaskProfile();
                if (nodeProfile == null) {
                    // This node returned no profile; nothing to merge.
                    continue;
                }

                if (nodeProfile.getNodeId() != null) {
                    // HC detector: task profile from coordinating node
                    // Single entity detector: task profile from worker node
                    mergedProfile.setTaskId(nodeProfile.getTaskId());
                    mergedProfile.setShingleSize(nodeProfile.getShingleSize());
                    mergedProfile.setRcfTotalUpdates(nodeProfile.getRcfTotalUpdates());
                    mergedProfile.setThresholdModelTrained(nodeProfile.getThresholdModelTrained());
                    mergedProfile.setThresholdModelTrainingDataSize(nodeProfile.getThresholdModelTrainingDataSize());
                    mergedProfile.setModelSizeInBytes(nodeProfile.getModelSizeInBytes());
                    mergedProfile.setNodeId(nodeProfile.getNodeId());
                    mergedProfile.setTotalEntitiesCount(nodeProfile.getTotalEntitiesCount());
                    mergedProfile.setDetectorTaskSlots(nodeProfile.getDetectorTaskSlots());
                    mergedProfile.setPendingEntitiesCount(nodeProfile.getPendingEntitiesCount());
                    mergedProfile.setRunningEntitiesCount(nodeProfile.getRunningEntitiesCount());
                    mergedProfile.setRunningEntities(nodeProfile.getRunningEntities());
                    mergedProfile.setAdTaskType(nodeProfile.getAdTaskType());
                }

                // Entity level profiles are accumulated across all nodes.
                List<ADEntityTaskProfile> nodeEntityProfiles = nodeProfile.getEntityTaskProfiles();
                if (nodeEntityProfiles != null) {
                    mergedEntityProfiles.addAll(nodeEntityProfiles);
                }
            }

            // Only attach entity profiles when at least one was collected.
            if (!mergedEntityProfiles.isEmpty()) {
                mergedProfile.setEntityTaskProfiles(mergedEntityProfiles);
            }
            listener.onResponse(mergedProfile);
        }, exception -> {
            logger.error("Failed to get task profile for task " + adDetectorLevelTask.getTaskId(), exception);
            listener.onFailure(exception);
        }));
    }, listener);
}
Aggregations