/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.task;

import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.naming.LimitExceededException;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ml.action.stats.MLStatsNodeResponse;
import org.opensearch.ml.action.stats.MLStatsNodesAction;
import org.opensearch.ml.action.stats.MLStatsNodesRequest;
import org.opensearch.ml.stats.InternalStatNames;
import org.opensearch.ml.stats.StatNames;

/**
 * Picks the data node on which a new ML task should run.
 *
 * <p>The dispatcher queries ML stats (executing task count and JVM heap usage) from every data node
 * in the current cluster state. It selects the node with the fewest executing ML tasks, breaking ties
 * by lowest heap usage. A node is only considered when its heap usage is below
 * {@link #DEFAULT_JVM_HEAP_USAGE_THRESHOLD} and its executing task count is below
 * {@code maxMLBatchTaskPerNode}.
 */
public class MLTaskDispatcher {
    @Generated
    private static final Logger log = LogManager.getLogger(MLTaskDispatcher.class);
    /** Nodes whose reported JVM heap usage is at or above this value are skipped (presumably a percentage). */
    private static final short DEFAULT_JVM_HEAP_USAGE_THRESHOLD = (short)85;
    private final ClusterService clusterService;
    private final Client client;
    // Upper bound (exclusive) on executing ML tasks for a node to remain a candidate.
    private volatile Integer maxMLBatchTaskPerNode;

    public MLTaskDispatcher(ClusterService clusterService, Client client) {
        this.clusterService = clusterService;
        this.client = client;
        this.maxMLBatchTaskPerNode = 10;
    }

    /**
     * Asynchronously selects the node that should run the next ML task.
     *
     * @param listener receives the chosen {@link DiscoveryNode}. It instead receives a
     *     {@link LimitExceededException} when no data node exists or no node satisfies the heap and
     *     task-count limits. Failures of the stats request (and any exception raised while handling
     *     its response) are forwarded to {@code onFailure} as well.
     */
    public void dispatchTask(ActionListener<DiscoveryNode> listener) {
        DiscoveryNode[] dataNodes = this.getEligibleDataNodes();
        if (dataNodes.length == 0) {
            // An empty node list would yield an empty stats response and a misleading
            // "memory usage exceeds" error, so fail fast with an accurate message instead.
            String errorMessage = "No eligible data node available to run ml jobs";
            log.warn(errorMessage);
            listener.onFailure(new LimitExceededException(errorMessage));
            return;
        }
        final String heapUsageStat = InternalStatNames.JVM_HEAP_USAGE.getName();
        final String taskCountStat = StatNames.ML_EXECUTING_TASK_COUNT;

        MLStatsNodesRequest statsRequest = new MLStatsNodesRequest(dataNodes);
        statsRequest.addAll(ImmutableSet.of(taskCountStat, heapUsageStat));
        this.client.execute(MLStatsNodesAction.INSTANCE, statsRequest, ActionListener.wrap(mlStatsResponse -> {
            List<MLStatsNodeResponse> candidates = mlStatsResponse
                .getNodes()
                .stream()
                .filter(stat -> isStatBelow(stat, heapUsageStat, DEFAULT_JVM_HEAP_USAGE_THRESHOLD))
                .collect(Collectors.toList());
            if (candidates.isEmpty()) {
                String errorMessage = "All nodes' memory usage exceeds limitation "
                    + DEFAULT_JVM_HEAP_USAGE_THRESHOLD
                    + ". No eligible node available to run ml jobs ";
                log.warn(errorMessage);
                listener.onFailure(new LimitExceededException(errorMessage));
                return;
            }
            // Read the volatile limit once so every node is judged against the same value.
            final long maxTasks = this.maxMLBatchTaskPerNode.longValue();
            candidates = candidates
                .stream()
                .filter(stat -> isStatBelow(stat, taskCountStat, maxTasks))
                .collect(Collectors.toList());
            if (candidates.isEmpty()) {
                String errorMessage = "All nodes' executing ML task count reach limitation.";
                log.warn(errorMessage);
                listener.onFailure(new LimitExceededException(errorMessage));
                return;
            }
            // Both stats are non-null for every remaining candidate (guaranteed by the filters above).
            Comparator<MLStatsNodeResponse> byLoad = Comparator
                .comparingLong((MLStatsNodeResponse r) -> getLongStat(r, taskCountStat))
                .thenComparingLong(r -> getLongStat(r, heapUsageStat));
            // Collections.min keeps the first element among ties, matching the previous sort-then-first order.
            MLStatsNodeResponse target = Collections.min(candidates, byLoad);
            listener.onResponse(target.getNode());
        }, exception -> {
            log.error("Failed to get node's task stats", exception);
            listener.onFailure(exception);
        }));
    }

    /**
     * Reads a numeric stat from a node's stats map.
     *
     * @return the stat as a {@code long}, or {@code null} when it is absent or not a {@link Number}
     */
    private static Long getLongStat(MLStatsNodeResponse nodeResponse, String statName) {
        Object value = nodeResponse.getStatsMap().get(statName);
        return value instanceof Number ? ((Number) value).longValue() : null;
    }

    /** Returns true when the node reports {@code statName} and its value is strictly below {@code limit}. */
    private static boolean isStatBelow(MLStatsNodeResponse nodeResponse, String statName, long limit) {
        Long value = getLongStat(nodeResponse, statName);
        return value != null && value < limit;
    }

    /** Collects every node in the current cluster state whose {@code isDataNode()} is true. */
    private DiscoveryNode[] getEligibleDataNodes() {
        ClusterState state = this.clusterService.state();
        List<DiscoveryNode> eligibleDataNodes = new ArrayList<>();
        for (DiscoveryNode node : state.nodes()) {
            if (node.isDataNode()) {
                eligibleDataNodes.add(node);
            }
        }
        return eligibleDataNodes.toArray(new DiscoveryNode[0]);
    }
}

