Usage of org.opensearch.ml.action.stats.MLStatsNodeResponse in the ml-commons project (opensearch-project).
Method dispatchTask of class MLTaskDispatcher.
/**
 * Selects the least loaded node to run an ML task and passes it to {@code listener}.
 *
 * <p>The {@code ML_EXECUTING_TASK_COUNT} and {@code JVM_HEAP_USAGE} stats are fetched from the
 * eligible data nodes via {@code MLStatsNodesAction}. The nodes are then filtered in two steps:
 * <ul>
 *   <li>nodes whose JVM heap usage is not below {@code DEFAULT_JVM_HEAP_USAGE_THRESHOLD} are dropped;</li>
 *   <li>nodes whose executing ML task count is not below {@code maxMLBatchTaskPerNode} are dropped.</li>
 * </ul>
 * From the remaining nodes, the one with the fewest executing ML tasks is chosen. Ties are broken
 * by lower JVM heap usage, and then by the order of the node responses.
 *
 * <p>The listener fails in these cases:
 * <ul>
 *   <li>with {@code LimitExceededException} if no node responded, or if every node exceeds one of
 *       the limits;</li>
 *   <li>with the original exception if the stats request itself fails;</li>
 *   <li>with {@code IllegalStateException} if a node's response lacks one of the stats or holds a
 *       non-numeric value for it (forwarded through {@code ActionListener.wrap}).</li>
 * </ul>
 *
 * @param listener Action listener notified with the selected node, or with the failure
 */
public void dispatchTask(ActionListener<DiscoveryNode> listener) {
    // TODO: add ML node type setting check and dispatch to getEligibleMLNodes() instead of data nodes
    DiscoveryNode[] mlNodes = getEligibleDataNodes();
    MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(mlNodes);
    mlStatsNodesRequest.addAll(ImmutableSet.of(ML_EXECUTING_TASK_COUNT, JVM_HEAP_USAGE.getName()));
    client.execute(MLStatsNodesAction.INSTANCE, mlStatsNodesRequest, ActionListener.wrap(mlStatsResponse -> {
        // Without this check, an empty node list would be misreported as "memory usage exceeds limitation".
        if (mlStatsResponse.getNodes().isEmpty()) {
            String errorMessage = "No eligible node available to run ml jobs";
            log.warn(errorMessage);
            listener.onFailure(new LimitExceededException(errorMessage));
            return;
        }
        // Check JVM pressure
        List<MLStatsNodeResponse> candidateNodeResponse = mlStatsResponse
            .getNodes()
            .stream()
            .filter(stat -> readLongStat(stat, JVM_HEAP_USAGE.getName()) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD)
            .collect(Collectors.toList());
        if (candidateNodeResponse.isEmpty()) {
            String errorMessage = "All nodes' memory usage exceeds limitation " + DEFAULT_JVM_HEAP_USAGE_THRESHOLD + ". No eligible node available to run ml jobs ";
            log.warn(errorMessage);
            listener.onFailure(new LimitExceededException(errorMessage));
            return;
        }
        // Check # of executing ML task
        candidateNodeResponse = candidateNodeResponse
            .stream()
            .filter(stat -> readLongStat(stat, ML_EXECUTING_TASK_COUNT) < maxMLBatchTaskPerNode)
            .collect(Collectors.toList());
        if (candidateNodeResponse.isEmpty()) {
            String errorMessage = "All nodes' executing ML task count reach limitation.";
            log.warn(errorMessage);
            listener.onFailure(new LimitExceededException(errorMessage));
            return;
        }
        // Pick the node with the fewest executing ML tasks, breaking ties by lower JVM heap usage.
        // Stream.min keeps the first element on ties, which matches the previous stable sort + findFirst.
        MLStatsNodeResponse targetNode = candidateNodeResponse.stream().min((r1, r2) -> {
            int result = Long.compare(readLongStat(r1, ML_EXECUTING_TASK_COUNT), readLongStat(r2, ML_EXECUTING_TASK_COUNT));
            if (result != 0) {
                return result;
            }
            return Long.compare(readLongStat(r1, JVM_HEAP_USAGE.getName()), readLongStat(r2, JVM_HEAP_USAGE.getName()));
        }).orElseThrow(() -> new IllegalStateException("No candidate node could be selected"));
        listener.onResponse(targetNode.getNode());
    }, exception -> {
        log.error("Failed to get node's task stats", exception);
        listener.onFailure(exception);
    }));
}

/**
 * Reads one numeric stat from a node's stats response as a {@code long}.
 *
 * @param response the node's stats response
 * @param statName key of the stat in the response's stats map
 * @return the stat value converted to {@code long}
 * @throws IllegalStateException if the stat is absent or not a {@link Number}; the message names
 *     the node and the stat so the failure can be diagnosed
 */
private static long readLongStat(MLStatsNodeResponse response, String statName) {
    Object value = response.getStatsMap().get(statName);
    if (!(value instanceof Number)) {
        throw new IllegalStateException(
            "Stat " + statName + " is missing or not numeric on node " + response.getNode().getId() + ": " + value
        );
    }
    return ((Number) value).longValue();
}
Aggregations