Search in sources:

Example 1 with ML_EXECUTING_TASK_COUNT

Use of org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT in project ml-commons by opensearch-project.

Method dispatchTask of class MLTaskDispatcher.

/**
 * Selects the least loaded eligible node for running an ML task and passes it to {@code listener}.
 *
 * <p>The {@code ML_EXECUTING_TASK_COUNT} and {@code JVM_HEAP_USAGE} stats are requested from
 * every node returned by {@code getEligibleDataNodes()}. A node remains a candidate only if its
 * JVM heap usage is below {@code DEFAULT_JVM_HEAP_USAGE_THRESHOLD} and its executing ML task
 * count is below {@code maxMLBatchTaskPerNode}. Among the candidates, the node with the fewest
 * executing ML tasks is chosen; ties are broken by lower JVM heap usage, and any remaining tie
 * keeps the node that appears first in the stats response.
 *
 * @param listener receives the selected {@link DiscoveryNode}. It is instead failed with a
 *     {@link LimitExceededException} when no node passes the heap filter or the task-count
 *     filter, or with the underlying exception when stats collection or evaluation fails.
 */
public void dispatchTask(ActionListener<DiscoveryNode> listener) {
    // todo: add ML node type setting check
    // DiscoveryNode[] mlNodes = getEligibleMLNodes();
    DiscoveryNode[] mlNodes = getEligibleDataNodes();
    MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(mlNodes);
    mlStatsNodesRequest.addAll(ImmutableSet.of(ML_EXECUTING_TASK_COUNT, JVM_HEAP_USAGE.getName()));
    client.execute(MLStatsNodesAction.INSTANCE, mlStatsNodesRequest, ActionListener.wrap(mlStatsResponse -> {
        // Check JVM pressure: drop nodes whose heap usage is at or above the threshold.
        List<MLStatsNodeResponse> candidateNodeResponse = mlStatsResponse
            .getNodes()
            .stream()
            .filter(stat -> getLongStat(stat, JVM_HEAP_USAGE.getName()) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD)
            .collect(Collectors.toList());
        if (candidateNodeResponse.isEmpty()) {
            String errorMessage = "All nodes' memory usage exceeds limitation " + DEFAULT_JVM_HEAP_USAGE_THRESHOLD + ". No eligible node available to run ml jobs ";
            log.warn(errorMessage);
            listener.onFailure(new LimitExceededException(errorMessage));
            return;
        }
        // Check # of executing ML tasks: drop nodes already at the per-node limit.
        candidateNodeResponse = candidateNodeResponse
            .stream()
            .filter(stat -> getLongStat(stat, ML_EXECUTING_TASK_COUNT) < maxMLBatchTaskPerNode)
            .collect(Collectors.toList());
        if (candidateNodeResponse.isEmpty()) {
            String errorMessage = "All nodes' executing ML task count reach limitation.";
            log.warn(errorMessage);
            listener.onFailure(new LimitExceededException(errorMessage));
            return;
        }
        // Pick the node with the fewest executing ML tasks, breaking ties by lower JVM heap usage.
        // min() is a single O(n) pass and, like a stable sort + findFirst, keeps the first element
        // on a full tie. The list is non-empty here, so a value is always present.
        MLStatsNodeResponse targetNode = candidateNodeResponse.stream().min((r1, r2) -> {
            int result = Long.compare(getLongStat(r1, ML_EXECUTING_TASK_COUNT), getLongStat(r2, ML_EXECUTING_TASK_COUNT));
            if (result != 0) {
                return result;
            }
            // JVM heap usage.
            return Long.compare(getLongStat(r1, JVM_HEAP_USAGE.getName()), getLongStat(r2, JVM_HEAP_USAGE.getName()));
        }).orElseThrow(() -> new IllegalStateException("No candidate node left after filtering"));
        listener.onResponse(targetNode.getNode());
    }, exception -> {
        log.error("Failed to get node's task stats", exception);
        listener.onFailure(exception);
    }));
}

/**
 * Reads a numeric stat from one node's stats response as a {@code long}.
 *
 * <p>Any {@link Number} subtype is accepted, so the caller does not depend on the exact boxed
 * type the stat was stored with.
 *
 * @param nodeStats the node's stats response
 * @param statName key of the stat in {@code nodeStats.getStatsMap()}
 * @return the stat value converted with {@link Number#longValue()}
 * @throws IllegalStateException if the stat is absent or not numeric
 */
private static long getLongStat(MLStatsNodeResponse nodeStats, String statName) {
    Object value = nodeStats.getStatsMap().get(statName);
    if (!(value instanceof Number)) {
        throw new IllegalStateException(
            "Stat " + statName + " is missing or not numeric on node " + nodeStats.getNode() + ": " + value
        );
    }
    return ((Number) value).longValue();
}
Also used: Client (org.opensearch.client.Client), ImmutableSet (com.google.common.collect.ImmutableSet), JVM_HEAP_USAGE (org.opensearch.ml.stats.InternalStatNames.JVM_HEAP_USAGE), Collectors (java.util.stream.Collectors), ML_EXECUTING_TASK_COUNT (org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT), ArrayList (java.util.ArrayList), List (java.util.List), ClusterState (org.opensearch.cluster.ClusterState), DiscoveryNode (org.opensearch.cluster.node.DiscoveryNode), MLStatsNodeResponse (org.opensearch.ml.action.stats.MLStatsNodeResponse), LimitExceededException (javax.naming.LimitExceededException), MLNodeUtils (org.opensearch.ml.utils.MLNodeUtils), Log4j2 (lombok.extern.log4j.Log4j2), ClusterService (org.opensearch.cluster.service.ClusterService), Optional (java.util.Optional), ActionListener (org.opensearch.action.ActionListener), MLStatsNodesRequest (org.opensearch.ml.action.stats.MLStatsNodesRequest), MLStatsNodesAction (org.opensearch.ml.action.stats.MLStatsNodesAction)

Aggregations

ImmutableSet (com.google.common.collect.ImmutableSet): 1; ArrayList (java.util.ArrayList): 1; List (java.util.List): 1; Optional (java.util.Optional): 1; Collectors (java.util.stream.Collectors): 1; LimitExceededException (javax.naming.LimitExceededException): 1; Log4j2 (lombok.extern.log4j.Log4j2): 1; ActionListener (org.opensearch.action.ActionListener): 1; Client (org.opensearch.client.Client): 1; ClusterState (org.opensearch.cluster.ClusterState): 1; DiscoveryNode (org.opensearch.cluster.node.DiscoveryNode): 1; ClusterService (org.opensearch.cluster.service.ClusterService): 1; MLStatsNodeResponse (org.opensearch.ml.action.stats.MLStatsNodeResponse): 1; MLStatsNodesAction (org.opensearch.ml.action.stats.MLStatsNodesAction): 1; MLStatsNodesRequest (org.opensearch.ml.action.stats.MLStatsNodesRequest): 1; JVM_HEAP_USAGE (org.opensearch.ml.stats.InternalStatNames.JVM_HEAP_USAGE): 1; ML_EXECUTING_TASK_COUNT (org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT): 1; MLNodeUtils (org.opensearch.ml.utils.MLNodeUtils): 1