/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.undeploy;

import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.remote.metadata.client.BulkDataObjectRequest;
import org.opensearch.remote.metadata.client.DataObjectRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.UpdateDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Nodes-level transport action that undeploys ML models and then reconciles the model index
 * and the other nodes with what was actually removed.
 *
 * <p>Each targeted node runs {@link #nodeOperation}, which snapshots the worker nodes of every
 * affected model and asks the {@link MLModelManager} to undeploy. Once all node responses are
 * gathered, {@link #processUndeployModelResponseAndUpdate} bulk-updates the model documents and
 * fires a sync-up request carrying the removed worker nodes.
 */
public class TransportUndeployModelAction
extends TransportNodesAction<MLUndeployModelNodesRequest, MLUndeployModelNodesResponse, MLUndeployModelNodeRequest, MLUndeployModelNodeResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportUndeployModelAction.class);
    private final MLModelManager mlModelManager;
    private final ClusterService clusterService;
    private final Client client;
    private final SdkClient sdkClient;
    private final DiscoveryNodeHelper nodeFilter;
    private final MLStats mlStats;

    /**
     * Creates the action and registers it under
     * {@code cluster:admin/opensearch/ml/undeploy_model}.
     *
     * @param transportService transport service handed to the base nodes action
     * @param actionFilters    action filters handed to the base nodes action
     * @param mlModelManager   model manager used to look up worker nodes and undeploy models
     * @param clusterService   cluster service used for the cluster name and the local node
     * @param threadPool       thread pool handed to the base nodes action
     * @param client           client used to stash the thread context and to run the sync-up action
     * @param sdkClient        metadata client used for the bulk update of model documents
     * @param nodeFilter       helper that supplies the nodes targeted by the sync-up request
     * @param mlStats          stats holder for the executing-task counter
     */
    @Inject
    public TransportUndeployModelAction(
        TransportService transportService,
        ActionFilters actionFilters,
        MLModelManager mlModelManager,
        ClusterService clusterService,
        ThreadPool threadPool,
        Client client,
        SdkClient sdkClient,
        DiscoveryNodeHelper nodeFilter,
        MLStats mlStats
    ) {
        super(
            "cluster:admin/opensearch/ml/undeploy_model",
            threadPool,
            clusterService,
            transportService,
            actionFilters,
            MLUndeployModelNodesRequest::new,
            MLUndeployModelNodeRequest::new,
            "management",
            MLUndeployModelNodeResponse.class
        );
        this.mlModelManager = mlModelManager;
        this.clusterService = clusterService;
        this.client = client;
        this.sdkClient = sdkClient;
        this.nodeFilter = nodeFilter;
        this.mlStats = mlStats;
    }

    /**
     * Runs the regular nodes action, but intercepts the aggregated response so the model index
     * and the other nodes are updated before the caller's listener is notified. Failures of the
     * nodes action itself are forwarded to {@code listener} unchanged.
     */
    @Override
    protected void doExecute(Task task, MLUndeployModelNodesRequest request, ActionListener<MLUndeployModelNodesResponse> listener) {
        ActionListener<MLUndeployModelNodesResponse> wrappedListener = ActionListener.wrap(
            nodesResponse -> this.processUndeployModelResponseAndUpdate(request.getTenantId(), nodesResponse, listener),
            listener::onFailure
        );
        super.doExecute(task, request, wrappedListener);
    }

    /**
     * Reconciles the model index and the cluster with the per-node undeploy results.
     *
     * <p>If no node responded, the response is returned as is. Otherwise the nodes that reported
     * {@code "undeployed"} are grouped per model. When at least one model was removed somewhere,
     * the model documents are bulk-updated first; the sync-up request is sent and
     * {@code listener} is answered with the original response once that update completes,
     * whether it succeeded or failed (a failure is only logged). When nothing was removed, the
     * sync-up is sent and the listener is answered immediately.
     *
     * @param tenantId                   tenant id attached to every model document update
     * @param undeployModelNodesResponse aggregated response of the nodes action
     * @param listener                   listener that receives {@code undeployModelNodesResponse}
     */
    void processUndeployModelResponseAndUpdate(String tenantId, MLUndeployModelNodesResponse undeployModelNodesResponse, ActionListener<MLUndeployModelNodesResponse> listener) {
        List<MLUndeployModelNodeResponse> responses = undeployModelNodesResponse.getNodes();
        if (responses == null || responses.isEmpty()) {
            listener.onResponse(undeployModelNodesResponse);
            return;
        }

        Map<String, List<String>> actualRemovedNodesMap = new HashMap<>();
        Map<String, String[]> modelWorkNodesBeforeRemoval = new HashMap<>();
        for (MLUndeployModelNodeResponse nodeResponse : responses) {
            mergeWorkerNodesBeforeRemoval(nodeResponse.getModelWorkerNodeBeforeRemoval(), modelWorkNodesBeforeRemoval);
            recordUndeployedModels(nodeResponse, actualRemovedNodesMap);
        }

        MLSyncUpInput syncUpInput = MLSyncUpInput.builder()
            .removedWorkerNodes(this.covertRemoveNodesMapForSyncUp(actualRemovedNodesMap))
            .build();
        MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(this.nodeFilter.getAllNodes(), syncUpInput);

        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            if (actualRemovedNodesMap.isEmpty()) {
                this.syncUpUndeployedModels(syncUpRequest);
                listener.onResponse(undeployModelNodesResponse);
                return;
            }

            BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().globalIndex(".plugins-ml-model").build();
            Map<String, Boolean> deployToAllNodes = new HashMap<>();
            for (Map.Entry<String, List<String>> removal : actualRemovedNodesMap.entrySet()) {
                String modelId = removal.getKey();
                Map<String, Object> updateDocument = buildModelUpdateDocument(
                    modelId,
                    modelWorkNodesBeforeRemoval.get(modelId),
                    removal.getValue(),
                    deployToAllNodes
                );
                UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest.builder()
                    .id(modelId)
                    .tenantId(tenantId)
                    .dataObject(updateDocument)
                    .build();
                bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            }
            syncUpInput.setDeployToAllNodes(deployToAllNodes);

            ActionListener<BulkResponse> stateUpdateListener = ActionListener.wrap(
                r -> log.debug(
                    "updated model state as undeployed for : {}",
                    Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0]))
                ),
                e -> log.error("Failed to update model state as undeployed", e)
            );
            // Best effort: whether or not the index update worked, still sync the nodes and
            // answer the caller with the undeploy result.
            ActionListener<BulkResponse> wrappedListener = ActionListener.runAfter(stateUpdateListener, () -> {
                this.syncUpUndeployedModels(syncUpRequest);
                listener.onResponse(undeployModelNodesResponse);
            });

            this.sdkClient.bulkDataObjectAsync(bulkRequest).whenComplete((r, throwable) -> {
                if (throwable != null) {
                    Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class);
                    log.error("Failed to execute BulkDataObject request", cause);
                    wrappedListener.onFailure(cause);
                    return;
                }
                try {
                    BulkResponse bulkResponse = BulkResponse.fromXContent(r.parser());
                    log.info(
                        "Executed {} bulk operations with {} failures, Took: {}",
                        bulkResponse.getItems().length,
                        bulkResponse.hasFailures() ? Arrays.stream(bulkResponse.getItems()).filter(BulkItemResponse::isFailed).count() : 0L,
                        bulkResponse.getTook()
                    );
                    wrappedListener.onResponse(bulkResponse);
                } catch (Exception e) {
                    wrappedListener.onFailure(e);
                }
            });
        }
    }

    /**
     * Folds one node's worker-node snapshot into {@code target}, keeping for each model the
     * longest non-null array reported so far (the first one wins on equal length).
     */
    private static void mergeWorkerNodesBeforeRemoval(Map<String, String[]> reported, Map<String, String[]> target) {
        if (reported == null) {
            return;
        }
        for (Map.Entry<String, String[]> entry : reported.entrySet()) {
            String[] candidate = entry.getValue();
            if (candidate == null) {
                continue;
            }
            String[] known = target.get(entry.getKey());
            if (known == null || known.length < candidate.length) {
                target.put(entry.getKey(), candidate);
            }
        }
    }

    /**
     * Adds the responding node's id to the removed-node list of every model for which that
     * node reported the status {@code "undeployed"}. A node without a status map is skipped.
     */
    private static void recordUndeployedModels(MLUndeployModelNodeResponse nodeResponse, Map<String, List<String>> removedNodesByModel) {
        Map<String, String> modelUndeployStatus = nodeResponse.getModelUndeployStatus();
        if (modelUndeployStatus == null) {
            return;
        }
        for (Map.Entry<String, String> entry : modelUndeployStatus.entrySet()) {
            if ("undeployed".equals(entry.getValue())) {
                removedNodesByModel.computeIfAbsent(entry.getKey(), id -> new ArrayList<>()).add(nodeResponse.getNode().getId());
            }
        }
    }

    /**
     * Builds the partial model document for one model.
     *
     * <p>If the number of removed nodes equals the number of worker nodes known before the
     * removal, the model is marked {@link MLModelState#UNDEPLOYED} with no planned workers.
     * Otherwise the remaining workers are written back, {@code deploy_to_all_nodes} is switched
     * off, and the model is registered in {@code deployToAllNodes} with {@code false}.
     *
     * @param modelId          id of the model being updated
     * @param workersBefore    worker nodes reported before the removal; may be {@code null}
     * @param removedNodes     ids of the nodes that undeployed the model
     * @param deployToAllNodes out-parameter collecting the models that are only partly undeployed
     * @return the fields to update on the model document
     */
    private static Map<String, Object> buildModelUpdateDocument(String modelId, String[] workersBefore, List<String> removedNodes, Map<String, Boolean> deployToAllNodes) {
        Map<String, Object> updateDocument = new HashMap<>();
        // NOTE(review): when no node reported worker nodes for this model, nothing is known to
        // remain deployed, so it is treated as fully undeployed instead of failing with an NPE.
        // Confirm this is the desired fallback.
        if (workersBefore == null || workersBefore.length == removedNodes.size()) {
            updateDocument.put("planning_worker_nodes", ImmutableList.of());
            updateDocument.put("planning_worker_node_count", 0);
            updateDocument.put("current_worker_node_count", 0);
            updateDocument.put("model_state", MLModelState.UNDEPLOYED);
            return updateDocument;
        }
        List<String> newPlanningWorkerNodes = Arrays.stream(workersBefore)
            .filter(nodeId -> !removedNodes.contains(nodeId))
            .collect(Collectors.toList());
        updateDocument.put("deploy_to_all_nodes", false);
        updateDocument.put("planning_worker_nodes", newPlanningWorkerNodes);
        updateDocument.put("planning_worker_node_count", newPlanningWorkerNodes.size());
        updateDocument.put("current_worker_node_count", newPlanningWorkerNodes.size());
        deployToAllNodes.put(modelId, false);
        return updateDocument;
    }

    /** Wraps the per-node responses and failures together with the cluster name. */
    @Override
    protected MLUndeployModelNodesResponse newResponse(MLUndeployModelNodesRequest nodesRequest, List<MLUndeployModelNodeResponse> responses, List<FailedNodeException> failures) {
        return new MLUndeployModelNodesResponse(this.clusterService.getClusterName(), responses, failures);
    }

    /**
     * Converts the model-id to removed-node-list map into the array-valued form passed to
     * {@code MLSyncUpInput}, logging each entry at debug level.
     */
    private Map<String, String[]> covertRemoveNodesMapForSyncUp(Map<String, List<String>> actualRemovedNodesMap) {
        Map<String, String[]> removedNodesMap = new HashMap<>();
        for (Map.Entry<String, List<String>> entry : actualRemovedNodesMap.entrySet()) {
            String[] removedNodes = entry.getValue().toArray(new String[0]);
            removedNodesMap.put(entry.getKey(), removedNodes);
            log.debug("removed node for model: {}, {}", entry.getKey(), Arrays.toString(removedNodes));
        }
        return removedNodesMap;
    }

    /** Fires the sync-up action; the outcome is only logged, never propagated. */
    private void syncUpUndeployedModels(MLSyncUpNodesRequest syncUpRequest) {
        this.client.execute(
            MLSyncUpAction.INSTANCE,
            syncUpRequest,
            ActionListener.wrap(
                r -> log.debug("sync up removed nodes successfully"),
                e -> log.error("failed to sync up removed node", e)
            )
        );
    }

    /** Creates the per-node request that wraps the nodes request. */
    @Override
    protected MLUndeployModelNodeRequest newNodeRequest(MLUndeployModelNodesRequest request) {
        return new MLUndeployModelNodeRequest(request);
    }

    /** Deserializes a per-node response from the stream. */
    @Override
    protected MLUndeployModelNodeResponse newNodeResponse(StreamInput in) throws IOException {
        return new MLUndeployModelNodeResponse(in);
    }

    /** Performs the undeploy on the local node. */
    @Override
    protected MLUndeployModelNodeResponse nodeOperation(MLUndeployModelNodeRequest request) {
        return this.createUndeployModelNodeResponse(request.getMlUndeployModelNodesRequest());
    }

    /**
     * Snapshots the worker nodes of the affected models, undeploys them through the model
     * manager, and reports both from the local node.
     *
     * <p>When the request names no model ids, the snapshot covers every id returned by
     * {@code getAllModelIds()}; the request's own (possibly null or empty) id array is still
     * what gets passed to {@code undeployModel}. The executing-task counter is decremented in a
     * {@code finally} block so that it is restored even if a lookup or the undeploy throws.
     */
    private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployModelNodesRequest undeployRequest) {
        this.mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
        try {
            String[] modelIds = undeployRequest.getModelIds();
            boolean specifiedModelIds = modelIds != null && modelIds.length > 0;
            String[] removedModelIds = specifiedModelIds ? modelIds : this.mlModelManager.getAllModelIds();

            Map<String, String[]> modelWorkerNodesMap = new HashMap<>();
            if (removedModelIds != null) {
                for (String modelId : removedModelIds) {
                    FunctionName functionName = this.mlModelManager.getModelFunctionName(modelId);
                    modelWorkerNodesMap.put(modelId, this.mlModelManager.getWorkerNodes(modelId, functionName));
                }
            }

            Map<String, String> modelUndeployStatus = this.mlModelManager.undeployModel(modelIds);
            return new MLUndeployModelNodeResponse(this.clusterService.localNode(), modelUndeployStatus, modelWorkerNodesMap);
        } finally {
            this.mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
        }
    }
}

