/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.deploy;

import com.google.common.annotations.VisibleForTesting;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.deploy.MLDeployModelInput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelNodesRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelNodesResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelOnNodeAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action registered as {@code cluster:admin/opensearch/ml/deploy_model}.
 *
 * <p>For each request it validates the tenant, loads the model metadata (excluding the
 * {@code model_content} and {@code content} fields), checks that the caller may operate on
 * the model, selects the target nodes, creates a {@link MLTaskType#DEPLOY_MODEL} task and
 * finally dispatches {@link MLDeployModelOnNodeAction} to the selected nodes.
 */
public class TransportDeployModelAction extends HandledTransportAction<ActionRequest, MLDeployModelResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportDeployModelAction.class);
    TransportService transportService;
    ModelHelper modelHelper;
    MLTaskManager mlTaskManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    SdkClient sdkClient;
    Settings settings;
    NamedXContentRegistry xContentRegistry;
    DiscoveryNodeHelper nodeFilter;
    MLTaskDispatcher mlTaskDispatcher;
    MLModelManager mlModelManager;
    MLStats mlStats;
    // Refreshed by the cluster-settings update consumer registered in the constructor.
    private volatile boolean allowCustomDeploymentPlan;
    private final ModelAccessControlHelper modelAccessControlHelper;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

    /**
     * Stores the injected collaborators, reads the initial value of
     * {@code ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN} from {@code settings} and registers a
     * consumer that keeps {@link #allowCustomDeploymentPlan} in sync with cluster settings updates.
     */
    @Inject
    public TransportDeployModelAction(TransportService transportService, ActionFilters actionFilters, ModelHelper modelHelper, MLTaskManager mlTaskManager, ClusterService clusterService, ThreadPool threadPool, Client client, SdkClient sdkClient, NamedXContentRegistry xContentRegistry, DiscoveryNodeHelper nodeFilter, MLTaskDispatcher mlTaskDispatcher, MLModelManager mlModelManager, MLStats mlStats, Settings settings, ModelAccessControlHelper modelAccessControlHelper, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
        super("cluster:admin/opensearch/ml/deploy_model", transportService, actionFilters, MLDeployModelRequest::new);
        this.transportService = transportService;
        this.modelHelper = modelHelper;
        this.mlTaskManager = mlTaskManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.sdkClient = sdkClient;
        this.xContentRegistry = xContentRegistry;
        this.nodeFilter = nodeFilter;
        this.mlTaskDispatcher = mlTaskDispatcher;
        this.mlModelManager = mlModelManager;
        this.mlStats = mlStats;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.settings = settings;
        this.allowCustomDeploymentPlan = MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN, it -> {
            this.allowCustomDeploymentPlan = it;
        });
    }

    /**
     * Entry point of the action.
     *
     * <p>Flow: validate the tenant id, stash the thread context, fetch the model, verify that the
     * required feature flag (remote inference / local model) is enabled, then:
     * <ul>
     *   <li>requests that are not user initiated are deployed directly;</li>
     *   <li>hidden models may only be deployed by a super admin;</li>
     *   <li>all other models require model-group access for the current user.</li>
     * </ul>
     *
     * @param task     the transport task (unused here)
     * @param request  the incoming request, converted via {@link MLDeployModelRequest#fromActionRequest}
     * @param listener receives the deploy response or the failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLDeployModelResponse> listener) {
        MLDeployModelRequest deployModelRequest = MLDeployModelRequest.fromActionRequest(request);
        String modelId = deployModelRequest.getModelId();
        String tenantId = deployModelRequest.getTenantId();
        if (!TenantAwareHelper.validateTenantId(this.mlFeatureEnabledSetting, tenantId, listener)) {
            return;
        }
        boolean isUserInitiatedDeployRequest = deployModelRequest.isUserInitiatedDeployRequest();
        // User and super-admin status are resolved before the thread context is stashed below.
        User user = RestActionUtils.getUserContext(this.client);
        boolean isSuperAdmin = this.isSuperAdminUserWrapper(this.clusterService, this.client);
        String[] excludes = new String[]{"model_content", "content"};
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            // Restores the stashed context right before the caller's listener is notified.
            ActionListener<MLDeployModelResponse> wrappedListener = ActionListener.runBefore(listener, context::restore);
            ActionListener<MLModel> modelListener = ActionListener.wrap(mlModel -> {
                FunctionName functionName = mlModel.getAlgorithm();
                Boolean isHidden = mlModel.getIsHidden();
                if (!TenantAwareHelper.validateTenantResource(this.mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), listener)) {
                    return;
                }
                if (functionName == FunctionName.REMOTE && !this.mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
                    throw new IllegalStateException("Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true.");
                }
                if (FunctionName.isDLModel(functionName) && !this.mlFeatureEnabledSetting.isLocalModelEnabled()) {
                    throw new IllegalStateException("Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.");
                }
                if (!isUserInitiatedDeployRequest) {
                    this.deployModel(deployModelRequest, mlModel, modelId, tenantId, wrappedListener, listener);
                } else if (Boolean.TRUE.equals(isHidden)) {
                    if (isSuperAdmin) {
                        this.deployModel(deployModelRequest, mlModel, modelId, tenantId, wrappedListener, listener);
                    } else {
                        wrappedListener.onFailure(new OpenSearchStatusException("User doesn't have privilege to perform this operation on this model", RestStatus.FORBIDDEN, new Object[0]));
                    }
                } else {
                    this.modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), this.client, ActionListener.wrap(access -> {
                        if (!access) {
                            wrappedListener.onFailure(new OpenSearchStatusException("User doesn't have privilege to perform this operation on this model", RestStatus.FORBIDDEN, new Object[0]));
                        } else {
                            this.deployModel(deployModelRequest, mlModel, modelId, tenantId, wrappedListener, listener);
                        }
                    }, accessFailure -> {
                        log.error(StringUtils.getErrorMessage("Failed to Validate Access for the given model", modelId, isHidden), accessFailure);
                        wrappedListener.onFailure(accessFailure);
                    }));
                }
            }, e -> {
                log.error("Failed to retrieve the ML model with the given ID", e);
                wrappedListener.onFailure(e);
            });
            this.mlModelManager.getModel(modelId, tenantId, null, excludes, modelListener);
        } catch (Exception e) {
            log.error("Failed to deploy the ML model", e);
            listener.onFailure(e);
        }
    }

    /**
     * Resolves the nodes to deploy on, creates the deploy task and triggers the deployment.
     *
     * <p>When the request names no target nodes the model is deployed to every eligible node.
     * Otherwise only requested nodes that are eligible are used, and the request is rejected if
     * the model is already deployed on nodes that are missing from the requested list.
     *
     * @param deployModelRequest the original request (source of the target node ids)
     * @param mlModel            metadata of the model being deployed
     * @param modelId            id of the model
     * @param tenantId           tenant the task is created for
     * @param wrappedListener    listener that restores the thread context before delegating
     * @param listener           the caller's listener; handed to {@link #deployRemoteModel} for remote models
     * @throws IllegalArgumentException if target nodes are given while custom deployment plans are disabled
     */
    private void deployModel(MLDeployModelRequest deployModelRequest, MLModel mlModel, String modelId, String tenantId, ActionListener<MLDeployModelResponse> wrappedListener, ActionListener<MLDeployModelResponse> listener) {
        String[] targetNodeIds = deployModelRequest.getModelNodeIds();
        boolean deployToAllNodes = targetNodeIds == null || targetNodeIds.length == 0;
        if (!this.allowCustomDeploymentPlan && !deployToAllNodes) {
            throw new IllegalArgumentException("Don't allow custom deployment plan");
        }
        FunctionName algorithm = mlModel.getAlgorithm();
        DiscoveryNode[] allEligibleNodes = this.nodeFilter.getEligibleNodes(algorithm);
        Map<String, DiscoveryNode> nodeMapping = new HashMap<>();
        for (DiscoveryNode discoveryNode : allEligibleNodes) {
            nodeMapping.put(discoveryNode.getId(), discoveryNode);
        }
        Set<String> allEligibleNodeIds = Arrays.stream(allEligibleNodes).map(DiscoveryNode::getId).collect(Collectors.toSet());
        List<DiscoveryNode> eligibleNodes = new ArrayList<>();
        List<String> eligibleNodeIds = new ArrayList<>();
        if (!deployToAllNodes) {
            // Keep only the requested nodes that are actually eligible for this algorithm.
            for (String nodeId : targetNodeIds) {
                if (!allEligibleNodeIds.contains(nodeId)) {
                    continue;
                }
                eligibleNodes.add(nodeMapping.get(nodeId));
                eligibleNodeIds.add(nodeId);
            }
            String[] currentWorkerNodes = this.mlModelManager.getWorkerNodes(modelId, algorithm);
            if (currentWorkerNodes != null && currentWorkerNodes.length > 0) {
                // Nodes that already run the model but were left out of the requested target list.
                Set<String> difference = new HashSet<>(Arrays.asList(currentWorkerNodes));
                Arrays.asList(targetNodeIds).forEach(difference::remove);
                if (!difference.isEmpty()) {
                    wrappedListener.onFailure(new IllegalArgumentException("Model already deployed to these nodes: " + Arrays.toString(difference.toArray(new String[0])) + ", but they are not included in target node ids. Undeploy model from these nodes if don't need them any more.Undeploy from old nodes before try to deploy model on new nodes. Or include all old nodes on your target nodes."));
                    return;
                }
            }
        } else {
            eligibleNodeIds.addAll(allEligibleNodeIds);
            eligibleNodes.addAll(Arrays.asList(allEligibleNodes));
        }
        if (eligibleNodeIds.isEmpty()) {
            wrappedListener.onFailure(new IllegalArgumentException("no eligible node found"));
            return;
        }
        log.info("Will deploy model on these nodes: {}", String.join(",", eligibleNodeIds));
        String localNodeId = this.clusterService.localNode().getId();
        MLTask mlTask = MLTask.builder().async(true).modelId(modelId).taskType(MLTaskType.DEPLOY_MODEL).functionName(algorithm).createTime(Instant.now()).lastUpdateTime(Instant.now()).state(MLTaskState.CREATED).workerNodes(eligibleNodeIds).tenantId(tenantId).build();
        ActionListener<IndexResponse> taskCreatedListener = ActionListener.wrap(response -> {
            String taskId = response.getId();
            mlTask.setTaskId(taskId);
            if (algorithm == FunctionName.REMOTE) {
                // Remote models answer the caller only after the on-node action has finished.
                this.mlTaskManager.add(mlTask, eligibleNodeIds);
                this.deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener);
                return;
            }
            try {
                this.mlTaskManager.add(mlTask, eligibleNodeIds);
                // Local models answer immediately with the CREATED task and continue on the deploy thread pool.
                wrappedListener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name()));
                // NOTE(review): if execute(...) throws after onResponse was sent, the catch block below
                // notifies the same listener again through onFailure - confirm this is acceptable.
                this.threadPool.executor("opensearch_ml_deploy").execute(() -> this.updateModelDeployStatusAndTriggerOnNodesAction(modelId, taskId, tenantId, mlModel, localNodeId, mlTask, eligibleNodes, deployToAllNodes));
            } catch (Exception ex) {
                log.error("Failed to deploy model", ex);
                this.mlTaskManager.updateMLTask(taskId, tenantId, Map.of("state", MLTaskState.FAILED, "error", MLExceptionUtils.getRootCauseMessage(ex)), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
                wrappedListener.onFailure(ex);
            }
        }, exception -> {
            // getIsHidden() may be null; the model id is deliberately left out of the log for hidden models.
            if (Boolean.TRUE.equals(mlModel.getIsHidden())) {
                log.error("Failed to create deploy model task for the provided model", exception);
            } else {
                log.error("Failed to create deploy model task for {}", modelId, exception);
            }
            wrappedListener.onFailure(exception);
        });
        this.mlTaskManager.createMLTask(mlTask, taskCreatedListener);
    }

    /**
     * Deploys a remote model: marks the model as {@link MLModelState#DEPLOYING} with the planned
     * worker nodes and, once that update succeeds, executes {@link MLDeployModelOnNodeAction} on
     * the eligible nodes. The outcome is reported to {@code listener} through
     * {@link #deployModelNodesResponseListener}.
     *
     * @param mlModel          metadata of the model being deployed
     * @param mlTask           the already created deploy task
     * @param localNodeId      id of the node coordinating the deployment
     * @param eligibleNodes    nodes the model is deployed on
     * @param deployToAllNodes whether the request targeted all eligible nodes
     * @param listener         receives the final response or failure
     */
    @VisibleForTesting
    void deployRemoteModel(MLModel mlModel, MLTask mlTask, String localNodeId, List<DiscoveryNode> eligibleNodes, boolean deployToAllNodes, ActionListener<MLDeployModelResponse> listener) {
        MLDeployModelInput deployModelInput = new MLDeployModelInput(mlModel.getModelId(), mlTask.getTaskId(), mlModel.getModelContentHash(), eligibleNodes.size(), localNodeId, deployToAllNodes, mlTask, mlModel.getTenantId());
        MLDeployModelNodesRequest deployModelRequest = new MLDeployModelNodesRequest(eligibleNodes.toArray(new DiscoveryNode[0]), deployModelInput);
        ActionListener<MLDeployModelNodesResponse> actionListener = this.deployModelNodesResponseListener(mlTask.getTaskId(), mlModel.getModelId(), mlModel.getTenantId(), mlModel.getIsHidden(), listener);
        List<String> workerNodes = eligibleNodes.stream().map(DiscoveryNode::getId).collect(Collectors.toList());
        ActionListener<UpdateResponse> modelUpdatedListener = ActionListener.wrap(
            r -> this.client.execute(MLDeployModelOnNodeAction.INSTANCE, deployModelRequest, actionListener),
            actionListener::onFailure
        );
        this.mlModelManager.updateModel(mlModel.getModelId(), mlModel.getTenantId(), Map.of("model_state", MLModelState.DEPLOYING, "planning_worker_node_count", eligibleNodes.size(), "planning_worker_nodes", workerNodes, "deploy_to_all_nodes", deployToAllNodes), modelUpdatedListener);
    }

    /**
     * Builds the listener used for the remote-model on-node action.
     *
     * <p>On success the task is moved to {@link MLTaskState#RUNNING} (if the task manager still
     * tracks it) and the caller receives a response carrying {@link MLTaskState#COMPLETED}. On
     * failure the task is marked {@link MLTaskState#FAILED}, the model is marked
     * {@link MLModelState#DEPLOY_FAILED} and the failure is forwarded to the caller.
     */
    private ActionListener<MLDeployModelNodesResponse> deployModelNodesResponseListener(String taskId, String modelId, String tenantId, Boolean isHidden, ActionListener<MLDeployModelResponse> listener) {
        return ActionListener.wrap(r -> {
            if (this.mlTaskManager.contains(taskId)) {
                this.mlTaskManager.updateMLTask(taskId, tenantId, Map.of("state", MLTaskState.RUNNING), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, false);
            }
            listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.COMPLETED.name()));
        }, e -> {
            log.error("Failed to deploy model {}", modelId, e);
            this.mlTaskManager.updateMLTask(taskId, tenantId, Map.of("error", MLExceptionUtils.getRootCauseMessage(e), "state", MLTaskState.FAILED), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
            this.mlModelManager.updateModel(modelId, tenantId, isHidden, Map.of("model_state", MLModelState.DEPLOY_FAILED));
            listener.onFailure(e);
        });
    }

    /**
     * Deploys a non-remote model: marks the model as {@link MLModelState#DEPLOYING} with the
     * planned worker nodes and, once that update succeeds, executes
     * {@link MLDeployModelOnNodeAction} on the eligible nodes.
     *
     * <p>No caller listener is involved here; the outcome is only recorded by updating the task
     * (RUNNING or FAILED) and, on failure, setting the model state to
     * {@link MLModelState#DEPLOY_FAILED}.
     */
    @VisibleForTesting
    void updateModelDeployStatusAndTriggerOnNodesAction(String modelId, String taskId, String tenantId, MLModel mlModel, String localNodeId, MLTask mlTask, List<DiscoveryNode> eligibleNodes, boolean deployToAllNodes) {
        MLDeployModelInput deployModelInput = new MLDeployModelInput(modelId, taskId, mlModel.getModelContentHash(), eligibleNodes.size(), localNodeId, deployToAllNodes, mlTask, tenantId);
        MLDeployModelNodesRequest deployModelRequest = new MLDeployModelNodesRequest(eligibleNodes.toArray(new DiscoveryNode[0]), deployModelInput);
        ActionListener<MLDeployModelNodesResponse> actionListener = ActionListener.wrap(r -> {
            if (this.mlTaskManager.contains(taskId)) {
                this.mlTaskManager.updateMLTask(taskId, mlModel.getTenantId(), Map.of("state", MLTaskState.RUNNING), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, false);
            }
        }, e -> {
            log.error("Failed to deploy model {}", modelId, e);
            this.mlTaskManager.updateMLTask(taskId, mlModel.getTenantId(), Map.of("error", MLExceptionUtils.getRootCauseMessage(e), "state", MLTaskState.FAILED), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
            this.mlModelManager.updateModel(modelId, mlModel.getTenantId(), mlModel.getIsHidden(), Map.of("model_state", MLModelState.DEPLOY_FAILED));
        });
        List<String> workerNodes = eligibleNodes.stream().map(DiscoveryNode::getId).toList();
        ActionListener<UpdateResponse> modelUpdatedListener = ActionListener.wrap(
            r -> this.client.execute(MLDeployModelOnNodeAction.INSTANCE, deployModelRequest, actionListener),
            actionListener::onFailure
        );
        this.mlModelManager.updateModel(modelId, mlModel.getTenantId(), Map.of("model_state", MLModelState.DEPLOYING, "planning_worker_node_count", eligibleNodes.size(), "planning_worker_nodes", workerNodes, "deploy_to_all_nodes", deployToAllNodes), modelUpdatedListener);
    }

    /**
     * Thin wrapper around {@link RestActionUtils#isSuperAdminUser(ClusterService, Client)}.
     *
     * @return the value returned by the delegate
     */
    @VisibleForTesting
    boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) {
        return RestActionUtils.isSuperAdminUser(clusterService, client);
    }
}

