/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.agent.tools;

import com.google.gson.reflect.TypeToken;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.Generated;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.admin.indices.get.GetIndexRequest;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest;
import org.opensearch.agent.tools.utils.ToolHelper;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.logging.LoggerMessageFormat;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.transport.client.Client;

@ToolAnnotation(value="CreateAlertTool")
public class CreateAlertTool
implements WithModelTool {
    /** Tool type identifier; also the initial value of {@link #name}. */
    public static final String TYPE = "CreateAlertTool";

    /** Class logger (Lombok-generated). */
    @Generated
    private static final Logger log = LogManager.getLogger(CreateAlertTool.class);

    /** Description used until {@link #setDescription(String)} replaces it; also returned by the factory as its default. */
    private static final String DEFAULT_DESCRIPTION = "This is a tool that helps to create an alert(i.e. monitor with triggers), some parameters should be parsed based on user's question and context. The parameters should include: \n1. indices: The input indices of the monitor, should be a list of string in json format.\n";

    /** File name handed to {@code ToolHelper.loadDefaultPromptDictFromFile} to load the default prompts. */
    private static final String PROMPT_FILE_PATH = "CreateAlertDefaultPrompt.json";

    /** Value of the {@code question} parameter when the caller does not supply one. */
    private static final String DEFAULT_QUESTION = "Create an alert as your recommendation based on the context";

    /** Default prompts, looked up by model type name; loaded once when the class is initialized. */
    private static final Map<String, String> promptDict = ToolHelper.loadDefaultPromptDictFromFile(CreateAlertTool.class, PROMPT_FILE_PATH);

    /** Tool name; starts out equal to {@link #TYPE}. */
    private String name = TYPE;

    /** Tool description; starts out equal to {@link #DEFAULT_DESCRIPTION}. */
    private String description = DEFAULT_DESCRIPTION;

    /** Client used for the index lookup and the model prediction call. */
    private final Client client;

    /** Id of the model that receives the prediction request. */
    private final String modelId;

    /** Normalized model type name, as produced by {@code ModelType.from}. */
    private final String modelType;

    /** Prompt template; either supplied by the caller or taken from {@link #promptDict}. */
    private final String toolPrompt;

    /** Free-form attributes; not assigned by the constructor, so {@code null} until set. */
    private Map<String, Object> attributes;

    public CreateAlertTool(Client client, String modelId, String modelType, String prompt) {
        this.client = client;
        this.modelId = modelId;
        this.modelType = String.valueOf((Object)ModelType.from(modelType));
        if (prompt.isEmpty()) {
            if (!promptDict.containsKey(this.modelType)) {
                throw new IllegalArgumentException(LoggerMessageFormat.format(null, (String)"Failed to find the right prompt for modelType: {}, this tool supports prompts for these models: [{}]", (Object[])new Object[]{modelType, String.join((CharSequence)",", promptDict.keySet())}));
            }
            this.toolPrompt = promptDict.get(this.modelType);
        } else {
            this.toolPrompt = prompt;
        }
    }

    /**
     * Returns the tool type identifier.
     *
     * @return the {@code TYPE} constant of this class
     */
    public String getType() {
        return TYPE;
    }

    /**
     * Returns the tool version; this tool does not report one.
     *
     * @return always {@code null}
     */
    public String getVersion() {
        return null;
    }

    /**
     * Accepts every parameter map without inspecting it.
     *
     * @param parameters tool parameters (not examined here)
     * @return always {@code true}
     */
    public boolean validate(Map<String, String> parameters) {
        return true;
    }

    /** Captures the body of a fenced {@code json} block in the model reply; DOTALL lets the body span lines. */
    private static final Pattern JSON_BLOCK_PATTERN = Pattern.compile("```json(.*?)```", Pattern.DOTALL);

    /**
     * Fetches the mappings of the requested indices, sends them to the configured model together with the
     * prompt, and passes the alert definition returned by the model to {@code listener} as a JSON string.
     *
     * <p>Parameters read here: {@code indices} (required) and {@code local} (optional, defaults to
     * {@code "true"}). The map is copied first, so the caller's map is never modified.
     *
     * @param parameters tool parameters
     * @param listener   receives the alert JSON on success; receives the failure if the index lookup,
     *                   the model call, or the parsing of the model reply fails
     * @param <T>        response type expected by the listener; the value delivered is a {@code String}
     * @throws IllegalArgumentException thrown directly (not via the listener) when {@code indices} is
     *                                  missing or empty, or when building the index request rejects it
     */
    public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
        Map<String, String> tmpParams = new HashMap<>(parameters);
        if (!tmpParams.containsKey("indices") || Strings.isEmpty(tmpParams.get("indices"))) {
            throw new IllegalArgumentException("No indices in the input parameter. Ask user to provide index as your final answer directly without using any other tools");
        }
        String rawIndex = tmpParams.get("indices");
        boolean isLocal = Boolean.parseBoolean(tmpParams.getOrDefault("local", "true"));
        GetIndexRequest getIndexRequest = CreateAlertTool.constructIndexRequest(rawIndex, isLocal);

        this.client.admin().indices().getIndex(getIndexRequest, ActionListener.wrap(response -> {
            if (response.indices().length == 0) {
                // Thrown inside the wrapped handler, so it is routed to the failure handler below.
                throw indicesNotFound(rawIndex);
            }

            // Render every index and its mapping as plain text for the prompt.
            StringBuilder sb = new StringBuilder();
            for (String index : response.indices()) {
                sb.append("index: ").append(index).append("\n\n");
                MappingMetadata mapping = (MappingMetadata) response.mappings().get(index);
                if (mapping == null) {
                    continue;
                }
                sb.append("mappings:\n");
                for (Map.Entry<String, Object> entry : mapping.sourceAsMap().entrySet()) {
                    sb.append(entry.getKey()).append("=").append(entry.getValue()).append('\n');
                }
                sb.append("\n\n");
            }
            String mappingInfo = sb.toString();

            ActionRequest request = this.constructMLPredictRequest(tmpParams, mappingInfo);
            this.client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
                ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput();
                // Only the first tensor of the first output is consulted.
                Map<String, ?> dataMap = Optional
                    .ofNullable(modelTensorOutput.getMlModelOutputs())
                    .flatMap(outputs -> outputs.stream().findFirst())
                    .flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst())
                    .map(ModelTensor::getDataAsMap)
                    .orElse(null);
                // The listener is generic, but this tool always answers with a String.
                @SuppressWarnings("unchecked")
                T alertJson = (T) extractAlertJson(dataMap);
                listener.onResponse(alertJson);
            }, e -> {
                log.error("Failed to run model " + this.modelId, e);
                listener.onFailure(e);
            }));
        }, e -> {
            // Pass the throwable itself so the stack trace is kept, as the model-failure log does.
            log.error("failed to get index mapping", e);
            if (e.toString().contains("IndexNotFoundException")) {
                listener.onFailure(indicesNotFound(rawIndex));
            } else {
                listener.onFailure(e);
            }
        }));
    }

    /** Builds the exception reported when none of the requested indices can be found. */
    private static IllegalArgumentException indicesNotFound(String rawIndex) {
        return new IllegalArgumentException(LoggerMessageFormat.format(null, "Cannot find provided indices {}. Ask user to check the provided indices as your final answer without using any other tools", new Object[]{rawIndex}));
    }

    /**
     * Turns the model's data map into the alert JSON string: the {@code response} entry (unwrapped from a
     * fenced json block when one is present), or the whole map serialized when there is no such entry.
     *
     * @throws IllegalArgumentException if {@code dataMap} is {@code null} or the result is not JSON
     */
    private static String extractAlertJson(Map<String, ?> dataMap) {
        if (dataMap == null) {
            throw new IllegalArgumentException("No dataMap returned from LLM.");
        }
        String alertInfo;
        if (dataMap.containsKey("response")) {
            alertInfo = (String) dataMap.get("response");
            Matcher jsonBlockMatcher = JSON_BLOCK_PATTERN.matcher(alertInfo);
            if (jsonBlockMatcher.find()) {
                // Inside a fenced block the quotes may arrive escaped; unescape them before validating.
                alertInfo = jsonBlockMatcher.group(1).replace("\\\"", "\"");
            }
        } else {
            alertInfo = StringUtils.toJson(dataMap);
        }
        if (!StringUtils.isJson(alertInfo)) {
            throw new IllegalArgumentException(LoggerMessageFormat.format(null, "The response from LLM is not a json: [{}]", new Object[]{alertInfo}));
        }
        return alertInfo;
    }

    /**
     * Fills in the prompt template and wraps the parameters in a remote-model prediction request.
     *
     * <p>{@code tmpParams} is modified in place: {@code mapping_info} and {@code prompt} are always
     * written, while {@code indices}, {@code chat_history} and {@code question} only receive a default
     * when they are absent.
     *
     * @param tmpParams   mutable working copy of the tool parameters
     * @param mappingInfo textual rendering of the index mappings to expose to the prompt
     * @return prediction request addressed to this tool's model
     */
    private ActionRequest constructMLPredictRequest(Map<String, String> tmpParams, String mappingInfo) {
        tmpParams.put("mapping_info", mappingInfo);

        // Placeholders the template may reference; keep caller-supplied values when present.
        tmpParams.putIfAbsent("question", DEFAULT_QUESTION);
        tmpParams.putIfAbsent("chat_history", "");
        tmpParams.putIfAbsent("indices", "");

        // Placeholders in the template have the form ${parameters.<key>}.
        String resolvedPrompt = new StringSubstitutor(tmpParams, "${parameters.", "}").replace(this.toolPrompt);
        tmpParams.put("prompt", resolvedPrompt);

        MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(tmpParams).build();
        MLInput input = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
        return new MLPredictionTaskRequest(this.modelId, input);
    }

    /**
     * Builds the get-index request for the indices named in {@code rawIndex}.
     *
     * <p>{@code rawIndex} is first read as a JSON array of strings. If that fails, or yields nothing, it is
     * treated as a comma-separated list of index names. The previous fallback split on {@code '.'}, which
     * broke dotted names such as {@code logs-2024.01.01} and turned {@code .kibana} into
     * {@code ["", "kibana"]}, so the system-index check below never saw the leading dot. Names are trimmed;
     * null and blank entries are dropped.
     *
     * @param rawIndex index names as supplied by the caller (JSON array or comma-separated string)
     * @param isLocal  value for the request's {@code local} flag
     * @return request for the parsed indices with strict-expand index options
     * @throws IllegalArgumentException if no index name remains after parsing, or if any name starts with
     *                                  {@code '.'} (treated as a system index)
     */
    private static GetIndexRequest constructIndexRequest(String rawIndex, Boolean isLocal) {
        if (rawIndex == null) {
            throw new IllegalArgumentException("The input indices is empty. Ask user to provide index as your final answer directly without using any other tools");
        }

        List<String> parsed;
        try {
            parsed = StringUtils.gson.fromJson(rawIndex, new TypeToken<List<String>>(){}.getType());
        } catch (Exception e) {
            // Not a JSON array of strings: the caller passed plain text, handled below.
            parsed = null;
        }
        if (parsed == null) {
            // Reached on a parse failure, and also when Gson yields null (e.g. for the literal "null").
            parsed = Arrays.asList(rawIndex.split(","));
        }

        String[] indices = parsed.stream()
            .filter(name -> name != null)
            .map(String::trim)
            .filter(name -> !name.isEmpty())
            .toArray(String[]::new);

        if (indices.length == 0) {
            throw new IllegalArgumentException("The input indices is empty. Ask user to provide index as your final answer directly without using any other tools");
        }
        if (Arrays.stream(indices).anyMatch(name -> name.startsWith("."))) {
            throw new IllegalArgumentException(LoggerMessageFormat.format(null, "The provided indices [{}] contains system index, which is not allowed. Ask user to check the provided indices as your final answer without using any other.", new Object[]{rawIndex}));
        }

        GetIndexRequest getIndexRequest = new GetIndexRequest();
        getIndexRequest.indices(indices);
        getIndexRequest.indicesOptions(IndicesOptions.strictExpand());
        getIndexRequest.local(isLocal.booleanValue());
        getIndexRequest.clusterManagerNodeTimeout(ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT);
        return getIndexRequest;
    }

    /**
     * Replaces the tool name.
     *
     * @param name new tool name
     */
    @Generated
    public void setName(String name) {
        this.name = name;
    }

    /**
     * Returns the tool name.
     *
     * @return current tool name
     */
    @Generated
    public String getName() {
        return this.name;
    }

    /**
     * Returns the tool description.
     *
     * @return current tool description
     */
    @Generated
    public String getDescription() {
        return this.description;
    }

    /**
     * Replaces the tool description.
     *
     * @param description new tool description
     */
    @Generated
    public void setDescription(String description) {
        this.description = description;
    }

    /**
     * Returns the id of the model this tool sends its prediction request to.
     *
     * @return model id given at construction
     */
    @Generated
    public String getModelId() {
        return this.modelId;
    }

    /**
     * Returns the model type name resolved at construction.
     *
     * @return normalized model type name
     */
    @Generated
    public String getModelType() {
        return this.modelType;
    }

    /**
     * Returns the prompt template selected at construction.
     *
     * @return prompt template used to build the model request
     */
    @Generated
    public String getToolPrompt() {
        return this.toolPrompt;
    }

    /**
     * Returns the attribute map.
     *
     * @return attributes, or {@code null} if {@link #setAttributes(Map)} has not been called
     */
    @Generated
    public Map<String, Object> getAttributes() {
        return this.attributes;
    }

    /**
     * Replaces the attribute map; the given map is stored as-is, not copied.
     *
     * @param attributes new attribute map
     */
    @Generated
    public void setAttributes(Map<String, Object> attributes) {
        this.attributes = attributes;
    }

    /** Model types this tool can select a default prompt for. */
    public static enum ModelType {
        CLAUDE,
        OPENAI;

        /**
         * Resolves a model type name case-insensitively.
         *
         * @param value model type name; may be {@code null} or empty
         * @return the matching constant, or {@link #CLAUDE} when {@code value} is {@code null}, empty,
         *         or not a known name (the unknown-name case is logged as an error)
         */
        public static ModelType from(String value) {
            // null used to throw a NullPointerException; treat it like "not specified".
            if (value == null || value.isEmpty()) {
                return CLAUDE;
            }
            try {
                // Locale.ROOT keeps the upper-casing independent of the JVM default locale.
                return ModelType.valueOf(value.toUpperCase(Locale.ROOT));
            } catch (IllegalArgumentException e) {
                log.error("Wrong Model type, should be CLAUDE or OPENAI");
                return CLAUDE;
            }
        }
    }

    /**
     * Factory that builds {@link CreateAlertTool} instances from a parameter map. A single shared instance
     * is obtained through {@link #getInstance()} and must be given a client via {@link #init(Client)}
     * before {@link #create(Map)} produces usable tools.
     */
    public static class Factory
    implements WithModelTool.Factory<CreateAlertTool> {
        private Client client;
        // volatile: the field is read outside the lock in getInstance(), so the write must be safely published.
        private static volatile Factory INSTANCE;

        /**
         * Returns the shared factory, creating it on first use.
         *
         * @return the single {@code Factory} instance
         */
        public static Factory getInstance() {
            Factory existing = INSTANCE;
            if (existing != null) {
                return existing;
            }
            synchronized (CreateAlertTool.class) {
                // Re-check under the lock: another thread may have created the instance meanwhile.
                if (INSTANCE == null) {
                    INSTANCE = new Factory();
                }
                return INSTANCE;
            }
        }

        /**
         * Stores the client handed to every tool created afterwards.
         *
         * @param client client for index lookups and model calls
         */
        public void init(Client client) {
            this.client = client;
        }

        /**
         * Creates a tool from configuration parameters.
         *
         * <p>Keys read: {@code model_id} (required), {@code model_type} (defaults to {@code CLAUDE}) and
         * {@code prompt} (defaults to the empty string, which selects the default prompt). A key that is
         * present but mapped to {@code null} is treated as absent.
         *
         * @param params tool configuration
         * @return a new tool using this factory's client
         * @throws IllegalArgumentException if {@code model_id} is missing or blank
         */
        public CreateAlertTool create(Map<String, Object> params) {
            String modelId = (String) params.get("model_id");
            if (org.apache.commons.lang3.StringUtils.isBlank(modelId)) {
                throw new IllegalArgumentException("model_id cannot be null or blank.");
            }
            String modelType = (String) params.getOrDefault("model_type", ModelType.CLAUDE.toString());
            if (modelType == null) {
                // getOrDefault returns the mapped null when the key exists with a null value.
                modelType = ModelType.CLAUDE.toString();
            }
            String prompt = (String) params.getOrDefault("prompt", "");
            if (prompt == null) {
                prompt = "";
            }
            return new CreateAlertTool(this.client, modelId, modelType, prompt);
        }

        /**
         * @return the description a tool has until it is overridden
         */
        public String getDefaultDescription() {
            return CreateAlertTool.DEFAULT_DESCRIPTION;
        }

        /**
         * @return the tool type identifier
         */
        public String getDefaultType() {
            return CreateAlertTool.TYPE;
        }

        /**
         * @return always {@code null}; no default version is reported
         */
        public String getDefaultVersion() {
            return null;
        }

        /**
         * @return the parameter keys that carry a model id; here only {@code model_id}
         */
        public List<String> getAllModelKeys() {
            return List.of("model_id");
        }
    }
}

