/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.forward;

import com.google.common.collect.ImmutableMap;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.forward.MLForwardInput;
import org.opensearch.ml.common.transport.forward.MLForwardRequest;
import org.opensearch.ml.common.transport.forward.MLForwardRequestType;
import org.opensearch.ml.common.transport.forward.MLForwardResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskCache;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

public class TransportForwardAction
extends HandledTransportAction<ActionRequest, MLForwardResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportForwardAction.class);

    // Collaborators: assigned exactly once in the constructor, hence final.
    private final ClusterService clusterService;
    private final MLTaskManager mlTaskManager;
    private final Client client;
    private final SdkClient sdkClient;
    private final MLModelManager mlModelManager;
    private final DiscoveryNodeHelper nodeHelper;
    private final Settings settings;

    // Minimum fraction of successfully deployed worker nodes required before the next
    // auto-redeploy is triggered. Read once from the node settings in the constructor.
    private volatile float modelAutoRedeploySuccessRatio;

    // Reassigned by the cluster-settings update consumer registered in the constructor and
    // read while handling forwarded requests. The consumer callback may run on a different
    // thread than request handling, so the field is volatile to make updates visible
    // (matching modelAutoRedeploySuccessRatio above).
    private volatile boolean enableAutoReDeployModel;

    private final MLModelAutoReDeployer mlModelAutoReDeployer;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

    @Inject
    public TransportForwardAction(TransportService transportService, ActionFilters actionFilters, MLTaskManager mlTaskManager, Client client, SdkClient sdkClient, MLModelManager mlModelManager, DiscoveryNodeHelper nodeHelper, Settings settings, ClusterService clusterService, MLModelAutoReDeployer mlModelAutoReDeployer, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
        super("cluster:admin/opensearch/mlinternal/forward", transportService, actionFilters, MLForwardRequest::new);
        this.mlTaskManager = mlTaskManager;
        this.client = client;
        this.sdkClient = sdkClient;
        this.mlModelManager = mlModelManager;
        this.nodeHelper = nodeHelper;
        this.settings = settings;
        this.clusterService = clusterService;
        this.mlModelAutoReDeployer = mlModelAutoReDeployer;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.modelAutoRedeploySuccessRatio = ((Float)MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_SUCCESS_RATIO.get(settings)).floatValue();
        this.enableAutoReDeployModel = (Boolean)MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE, it -> {
            this.enableAutoReDeployModel = it;
        });
    }

    /**
     * Handles a forwarded ML request.
     *
     * <p>Two request types are supported:
     * <ul>
     *   <li>{@code DEPLOY_MODEL_DONE} - a worker node reports the outcome of deploying a model.
     *       The node is removed from the task's pending work-node set, the outcome is recorded,
     *       and once no pending node remains the task and model documents are updated with the
     *       final state.</li>
     *   <li>{@code REGISTER_MODEL} - delegates to {@code MLModelManager.registerMLModel}.</li>
     * </ul>
     * On success the listener receives an {@code MLForwardResponse} with status {@code "ok"}.
     * Any exception thrown while handling the request (including an unsupported request type or
     * a missing task cache) is logged and passed to {@code listener.onFailure}.
     *
     * @param task     the transport task (unused here)
     * @param request  the incoming request, converted via {@code MLForwardRequest.fromActionRequest}
     * @param listener receives the response or the failure
     */
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLForwardResponse> listener) {
        MLForwardRequest mlForwardRequest = MLForwardRequest.fromActionRequest(request);
        MLForwardInput forwardInput = mlForwardRequest.getForwardInput();
        String modelId = forwardInput.getModelId();
        String taskId = forwardInput.getTaskId();
        String tenantId = forwardInput.getTenantId();
        if (!TenantAwareHelper.validateTenantId(this.mlFeatureEnabledSetting, tenantId, listener)) {
            // NOTE(review): presumably validateTenantId has already notified the listener - confirm.
            return;
        }
        MLRegisterModelInput registerModelInput = forwardInput.getRegisterModelInput();
        MLTask mlTask = forwardInput.getMlTask();
        String workerNodeId = forwardInput.getWorkerNodeId();
        MLForwardRequestType requestType = forwardInput.getRequestType();
        String error = forwardInput.getError();
        log.debug("receive forward request: {}", requestType);
        try {
            switch (requestType) {
                case DEPLOY_MODEL_DONE: {
                    Set<String> workNodes = this.mlTaskManager.getWorkNodes(taskId);
                    MLTaskCache mlTaskCache = this.mlTaskManager.getMLTaskCache(taskId);
                    if (mlTaskCache == null) {
                        // Fail with a descriptive error instead of an anonymous NullPointerException.
                        throw new IllegalStateException("ML task cache not found for task: " + taskId);
                    }
                    FunctionName functionName = mlTaskCache.getMlTask().getFunctionName();

                    // The reporting node is no longer pending.
                    if (workNodes != null) {
                        workNodes.remove(workerNodeId);
                    }

                    // Record this node's outcome: either its error, or its membership as a worker.
                    if (error != null) {
                        this.mlTaskManager.addNodeError(taskId, workerNodeId, error);
                    } else {
                        this.mlModelManager.addModelWorkerNode(modelId, workerNodeId);
                        this.syncModelWorkerNodes(modelId, functionName);
                    }

                    // Pending nodes that are no longer part of the cluster will never report back;
                    // stop waiting for them.
                    Set<String> workNodesRemovedFromCluster = new HashSet<>();
                    if (workNodes != null && !workNodes.isEmpty()) {
                        Set<String> allNodesInCluster = new HashSet<>(List.of(RestActionUtils.getAllNodes(this.clusterService)));
                        workNodesRemovedFromCluster = workNodes
                            .stream()
                            .filter(node -> !allNodesInCluster.contains(node))
                            .collect(Collectors.toSet());
                        if (!workNodesRemovedFromCluster.isEmpty()) {
                            workNodes.removeAll(workNodesRemovedFromCluster);
                        }
                    }

                    // No pending node left: finalize the task and the model.
                    if (workNodes == null || workNodes.isEmpty()) {
                        if (!workNodesRemovedFromCluster.isEmpty()) {
                            mlTaskCache.updateWorkerNode(workNodesRemovedFromCluster);
                            this.mlModelManager.removeModelWorkerNode(modelId, false, workNodesRemovedFromCluster.toArray(new String[0]));
                        }

                        // Evaluate the outcome once so the task state and the model state
                        // written below are derived from the same observation.
                        boolean deployFailedOnAllNodes = mlTaskCache.allNodeFailed() || mlTaskCache.getWorkerNodeSize() == 0;

                        int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize();
                        MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED;
                        if (deployFailedOnAllNodes) {
                            taskState = MLTaskState.FAILED;
                            currentWorkerNodeCount = 0;
                        } else {
                            this.syncModelWorkerNodes(modelId, functionName);
                        }

                        ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
                        builder.put("state", taskState);
                        if (mlTaskCache.hasError()) {
                            // Nodes that reported an error do not count as current workers.
                            currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize() - mlTaskCache.getErrors().size();
                            builder.put("error", MLExceptionUtils.toJsonString(mlTaskCache.getErrors()));
                        }

                        // NOTE(review): kept ahead of updateMLTask as in the original; presumably the
                        // trailing `true` argument can evict the task cache this helper reads - confirm.
                        boolean clearAutoReDeployRetryTimes = this.triggerNextModelDeployAndCheckIfRestRetryTimes(workNodes, taskId);
                        this.mlTaskManager.updateMLTask(taskId, tenantId, builder.build(), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);

                        MLModelState modelState;
                        if (deployFailedOnAllNodes) {
                            modelState = MLModelState.DEPLOY_FAILED;
                            log.error("deploy model failed on all nodes, model id: {}", modelId);
                        } else {
                            modelState = mlTaskCache.hasError() ? MLModelState.PARTIALLY_DEPLOYED : MLModelState.DEPLOYED;
                        }

                        Map<String, Object> updateFields = new HashMap<>();
                        updateFields.put("model_state", modelState);
                        updateFields.put("last_deployed_time", Instant.now().toEpochMilli());
                        updateFields.put("current_worker_node_count", currentWorkerNodeCount);
                        if (clearAutoReDeployRetryTimes) {
                            log.debug("Model successfully deployed in cluster, setting the auto retry times to 0");
                            updateFields.put("auto_redeploy_retry_times", 0);
                        }
                        log.info("deploy model done with state: {}, model id: {}", modelState, modelId);

                        // The model update is fire-and-forget with respect to the forward response:
                        // its outcome is only logged.
                        ActionListener<UpdateResponse> updateModelListener = ActionListener.wrap(response -> {
                            if (response.status() == RestStatus.OK) {
                                log.debug("Updated ML model successfully: {}, model id: {}", response.status(), modelId);
                            } else {
                                log.error("Failed to update ML model {}, status: {}", modelId, response.status());
                            }
                        }, updateFailure -> log.error("Failed to update ML model: {}", modelId, updateFailure));

                        // removeAutoDeployModel runs before the logging listener is notified.
                        this.mlModelManager
                            .updateModel(
                                modelId,
                                tenantId,
                                updateFields,
                                ActionListener.runBefore(updateModelListener, () -> this.mlModelManager.removeAutoDeployModel(modelId))
                            );
                    }
                    listener.onResponse(new MLForwardResponse("ok", null));
                    break;
                }
                case REGISTER_MODEL: {
                    this.mlModelManager.registerMLModel(registerModelInput, mlTask);
                    listener.onResponse(new MLForwardResponse("ok", null));
                    break;
                }
                default: {
                    throw new IllegalArgumentException("unsupported request type");
                }
            }
        } catch (Exception e) {
            MLExceptionUtils.logException("Failed to execute forward action " + requestType, e, log);
            listener.onFailure(e);
        }
    }

    /**
     * Triggers the next model redeploy when auto-redeploy is enabled and a sufficient share of
     * this task's worker nodes deployed successfully.
     *
     * <p>The success share is computed as
     * {@code (expected - pending - errored) / expected}, where {@code expected} is the task
     * cache's worker-node size, {@code pending} is {@code workNodes.size()} and {@code errored}
     * is the task cache's error-node count. If that share is at least
     * {@code modelAutoRedeploySuccessRatio}, {@code redeployAModel()} is invoked.
     *
     * @param workNodes nodes that have not reported yet; {@code null} disables the check
     * @param taskId    id of the deploy task whose cache is consulted
     * @return {@code true} if the redeploy was triggered, which the caller uses as the signal
     *         to reset the model's auto-redeploy retry counter; {@code false} otherwise
     */
    private boolean triggerNextModelDeployAndCheckIfRestRetryTimes(Set<String> workNodes, String taskId) {
        if (!this.enableAutoReDeployModel || workNodes == null) {
            return false;
        }
        // Look the cache up once: a separate null check followed by a second lookup could
        // observe the entry disappearing in between.
        MLTaskCache mlTaskCache = this.mlTaskManager.getMLTaskCache(taskId);
        if (mlTaskCache == null) {
            return false;
        }

        int expectedWorkerNodeCount = mlTaskCache.getWorkerNodeSize();
        int receivedWorkerNodesCount = expectedWorkerNodeCount - workNodes.size();
        int successWorkerNodesCount = receivedWorkerNodesCount - mlTaskCache.errorNodesCount();

        // Float division: with an expected count of 0 the quotient is NaN (or -Infinity),
        // which never satisfies the comparison, so no explicit zero check is required.
        float successRatio = (float) successWorkerNodesCount / (float) expectedWorkerNodeCount;
        if (successRatio >= this.modelAutoRedeploySuccessRatio) {
            this.mlModelAutoReDeployer.redeployAModel();
            return true;
        }
        return false;
    }

    /**
     * Sends the current worker nodes of a model to all nodes returned by the node helper.
     *
     * <p>Nothing is sent when the helper reports at most one node or when the model has no
     * worker nodes. The sync-up request is executed asynchronously; its outcome is only logged.
     *
     * @param modelId      id of the model whose worker nodes are published
     * @param functionName function name passed to the worker-node lookup
     */
    private void syncModelWorkerNodes(String modelId, FunctionName functionName) {
        final DiscoveryNode[] clusterNodes = this.nodeHelper.getAllNodes();
        final Object[] modelWorkers = this.mlModelManager.getWorkerNodes(modelId, functionName);

        // With a single node there is nobody else to inform.
        if (clusterNodes.length <= 1) {
            return;
        }
        // No known workers for this model: nothing to publish.
        if (modelWorkers == null || modelWorkers.length == 0) {
            return;
        }

        log.debug("Sync to other nodes about worker nodes of model {}: {}", (Object) modelId, (Object) Arrays.toString(modelWorkers));

        final MLSyncUpInput input = MLSyncUpInput
            .builder()
            .addedWorkerNodes((Map) ImmutableMap.of((Object) modelId, (Object) modelWorkers))
            .build();
        final MLSyncUpNodesRequest nodesRequest = new MLSyncUpNodesRequest(clusterNodes, input);

        this.client
            .execute(
                (ActionType) MLSyncUpAction.INSTANCE,
                (ActionRequest) nodesRequest,
                ActionListener.wrap(r -> log.debug("Sync up successfully"), e -> log.error("Failed to sync up", (Throwable) e))
            );
    }
}

