/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.agent.tools;

import com.google.gson.Gson;
import com.google.gson.JsonObject;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.agent.tools.AbstractRetrieverTool;
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.VectorDBTool;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.transport.client.Client;

/**
 * Retrieval-augmented generation tool.
 *
 * <p>Runs the configured retriever ({@link AbstractRetrieverTool}) on the request's {@code input}.
 * When content generation is disabled, the raw retriever result is returned as-is. Otherwise the
 * search hits are reformatted into an {@code _id:}/{@code _source:} text block, passed to the remote
 * inference model under the {@code output_field} parameter, and the model output (optionally run
 * through {@link #getOutputParser()}) is returned.
 */
@ToolAnnotation(value="RAGTool")
public class RAGTool implements WithModelTool {
    @Generated
    private static final Logger log = LogManager.getLogger(RAGTool.class);
    public static final String TYPE = "RAGTool";
    public static String DEFAULT_DESCRIPTION = "Use this tool to retrieve helpful information to optimize the output of the large language model to answer questions.";
    public static final String INFERENCE_MODEL_ID_FIELD = "inference_model_id";
    public static final String EMBEDDING_MODEL_ID_FIELD = "embedding_model_id";
    public static final String INDEX_FIELD = "index";
    public static final String SOURCE_FIELD = "source_field";
    public static final String DOC_SIZE_FIELD = "doc_size";
    public static final String EMBEDDING_FIELD = "embedding_field";
    public static final String OUTPUT_FIELD = "output_field";
    public static final String QUERY_TYPE = "query_type";
    public static final String CONTENT_GENERATION_FIELD = "enable_content_generation";
    public static final String K_FIELD = "k";
    /**
     * Retriever result treated as "no hits"; it is replaced by an empty context string.
     * NOTE(review): presumably emitted by the retriever tools — confirm the exact text stays in sync.
     */
    private static final String NO_MATCH_RESPONSE = "Can not get any match from search result.";
    private final AbstractRetrieverTool queryTool;
    private String name = "RAGTool";
    private String description = DEFAULT_DESCRIPTION;
    private Client client;
    private String inferenceModelId;
    private Boolean enableContentGeneration;
    private NamedXContentRegistry xContentRegistry;
    private String queryType;
    private Parser inputParser;
    private Parser outputParser;
    private Map<String, Object> attributes;

    /**
     * Creates a RAG tool.
     *
     * @param client                  client used to call the inference model
     * @param xContentRegistry        content registry (stored, not used by this class directly)
     * @param inferenceModelId        id of the model that generates the final answer
     * @param enableContentGeneration when false, {@link #run} returns the retriever result unchanged;
     *                                must be non-null when {@link #run} is invoked
     * @param queryTool               retriever used to fetch context documents
     */
    public RAGTool(Client client, NamedXContentRegistry xContentRegistry, String inferenceModelId, Boolean enableContentGeneration, AbstractRetrieverTool queryTool) {
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.inferenceModelId = inferenceModelId;
        this.enableContentGeneration = enableContentGeneration;
        this.queryTool = queryTool;
        // Default parser: extracts the "response" entry of the first tensor of the first model output.
        this.outputParser = RAGTool::extractResponse;
    }

    /**
     * Extracts {@code response} from the first {@link ModelTensor} of the first {@link ModelTensors}.
     *
     * @param modelOutputs the model output list (expected to be a {@code List<ModelTensors>})
     * @return the value stored under {@code "response"}, or null if absent
     */
    private static Object extractResponse(Object modelOutputs) {
        List<?> outputs = (List<?>) modelOutputs;
        ModelTensors firstOutput = (ModelTensors) outputs.get(0);
        ModelTensor firstTensor = firstOutput.getMlModelTensors().get(0);
        return firstTensor.getDataAsMap().get("response");
    }

    /**
     * Retrieves context for {@code parameters["input"]} and, if enabled, generates an answer from it.
     *
     * @param parameters request parameters; {@code input} is required, {@code tenant_id} is optional and
     *                   forwarded to the prediction request; all entries are forwarded to the model
     * @param listener   receives the retriever result or the (parsed) model output, or the failure
     * @throws IllegalArgumentException synchronously, if {@code parameters} is null/empty or has no
     *                                  non-blank {@code input}
     */
    public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
        // Validate before touching the map so a null map yields the documented IllegalArgumentException.
        if (!this.validate(parameters)) {
            throw new IllegalArgumentException("[input] is null or empty, can not process it.");
        }
        String tenantId = parameters.get("tenant_id");
        String input = parameters.get("input");

        ActionListener<Object> retrievalListener = ActionListener.wrap(
            searchResult -> this.onRetrieved(searchResult, parameters, tenantId, listener),
            e -> {
                log.error("Failed to search index.", e);
                listener.onFailure(e);
            }
        );
        this.queryTool.run(Map.of("input", input), retrievalListener);
    }

    /**
     * Handles the retriever result: returns it directly, or feeds it to the inference model.
     * Exceptions thrown here propagate into the wrapping listener created in {@link #run}.
     */
    @SuppressWarnings("unchecked")
    private <T> void onRetrieved(Object searchResult, Map<String, String> parameters, String tenantId, ActionListener<T> listener) {
        if (!this.enableContentGeneration.booleanValue()) {
            listener.onResponse((T) searchResult);
            return;
        }
        String context = formatSearchHits(searchResult);
        this.predict(context, parameters, tenantId, listener);
    }

    /**
     * Converts newline-delimited JSON hits into a JSON-encoded string of {@code _id:}/{@code _source:} lines.
     * Each line of {@code searchResult.toString()} is parsed as a JSON object with {@code _id} and
     * {@code _source} members.
     *
     * @return "" for the no-match sentinel, otherwise the JSON string literal of the formatted text
     */
    private static String formatSearchHits(Object searchResult) {
        if (NO_MATCH_RESPONSE.equals(searchResult)) {
            return "";
        }
        Gson gson = new Gson();
        StringBuilder resultBuilder = new StringBuilder();
        for (String hit : searchResult.toString().split("\n")) {
            JsonObject jsonObject = gson.fromJson(hit, JsonObject.class);
            String id = jsonObject.get("_id").getAsString();
            JsonObject source = jsonObject.getAsJsonObject("_source");
            resultBuilder.append("_id: ").append(id).append("\n");
            resultBuilder.append("_source: ").append(source.toString()).append("\n");
        }
        // Encoded as a JSON string literal (quoted, escaped) so it can be embedded in the model prompt.
        return gson.toJson((Object) resultBuilder.toString());
    }

    /**
     * Calls the remote inference model with the request parameters plus {@code output_field = context}.
     */
    @SuppressWarnings("unchecked")
    private <T> void predict(String context, Map<String, String> parameters, String tenantId, ActionListener<T> listener) {
        Map<String, String> inferenceParameters = new HashMap<>(parameters);
        inferenceParameters.put(OUTPUT_FIELD, context);
        RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(inferenceParameters).build();
        MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build();
        MLPredictionTaskRequest request = new MLPredictionTaskRequest(this.inferenceModelId, mlInput, null, tenantId);
        this.client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(resp -> {
            ModelTensorOutput modelTensorOutput = (ModelTensorOutput) resp.getOutput();
            Object modelOutputs = modelTensorOutput.getMlModelOutputs();
            Object result = this.outputParser == null ? modelOutputs : this.outputParser.parse(modelOutputs);
            listener.onResponse((T) result);
        }, e -> {
            log.error("Failed to run model " + this.inferenceModelId, e);
            listener.onFailure(e);
        }));
    }

    public String getType() {
        return TYPE;
    }

    public String getVersion() {
        return null;
    }

    public String getName() {
        return this.name;
    }

    public void setName(String s) {
        this.name = s;
    }

    /**
     * @return true when {@code parameters} is non-empty and contains a non-blank {@code input}
     */
    public boolean validate(Map<String, String> parameters) {
        if (parameters == null || parameters.isEmpty()) {
            return false;
        }
        String question = parameters.get("input");
        return question != null && !question.trim().isEmpty();
    }

    @Generated
    public static RAGToolBuilder builder() {
        return new RAGToolBuilder();
    }

    @Generated
    public void setDescription(String description) {
        this.description = description;
    }

    @Generated
    public void setClient(Client client) {
        this.client = client;
    }

    @Generated
    public void setInferenceModelId(String inferenceModelId) {
        this.inferenceModelId = inferenceModelId;
    }

    @Generated
    public void setEnableContentGeneration(Boolean enableContentGeneration) {
        this.enableContentGeneration = enableContentGeneration;
    }

    @Generated
    public void setXContentRegistry(NamedXContentRegistry xContentRegistry) {
        this.xContentRegistry = xContentRegistry;
    }

    @Generated
    public void setQueryType(String queryType) {
        this.queryType = queryType;
    }

    @Generated
    public void setAttributes(Map<String, Object> attributes) {
        this.attributes = attributes;
    }

    @Generated
    public AbstractRetrieverTool getQueryTool() {
        return this.queryTool;
    }

    @Generated
    public String getDescription() {
        return this.description;
    }

    @Generated
    public Client getClient() {
        return this.client;
    }

    @Generated
    public String getInferenceModelId() {
        return this.inferenceModelId;
    }

    @Generated
    public Boolean getEnableContentGeneration() {
        return this.enableContentGeneration;
    }

    @Generated
    public NamedXContentRegistry getXContentRegistry() {
        return this.xContentRegistry;
    }

    @Generated
    public String getQueryType() {
        return this.queryType;
    }

    @Generated
    public Parser getInputParser() {
        return this.inputParser;
    }

    @Generated
    public Parser getOutputParser() {
        return this.outputParser;
    }

    @Generated
    public Map<String, Object> getAttributes() {
        return this.attributes;
    }

    @Generated
    public void setInputParser(Parser inputParser) {
        this.inputParser = inputParser;
    }

    @Generated
    public void setOutputParser(Parser outputParser) {
        this.outputParser = outputParser;
    }

    /** Builder mirroring the five-argument {@link RAGTool} constructor. */
    @Generated
    public static class RAGToolBuilder {
        @Generated
        private Client client;
        @Generated
        private NamedXContentRegistry xContentRegistry;
        @Generated
        private String inferenceModelId;
        @Generated
        private Boolean enableContentGeneration;
        @Generated
        private AbstractRetrieverTool queryTool;

        @Generated
        RAGToolBuilder() {
        }

        @Generated
        public RAGToolBuilder client(Client client) {
            this.client = client;
            return this;
        }

        @Generated
        public RAGToolBuilder xContentRegistry(NamedXContentRegistry xContentRegistry) {
            this.xContentRegistry = xContentRegistry;
            return this;
        }

        @Generated
        public RAGToolBuilder inferenceModelId(String inferenceModelId) {
            this.inferenceModelId = inferenceModelId;
            return this;
        }

        @Generated
        public RAGToolBuilder enableContentGeneration(Boolean enableContentGeneration) {
            this.enableContentGeneration = enableContentGeneration;
            return this;
        }

        @Generated
        public RAGToolBuilder queryTool(AbstractRetrieverTool queryTool) {
            this.queryTool = queryTool;
            return this;
        }

        @Generated
        public RAGTool build() {
            return new RAGTool(this.client, this.xContentRegistry, this.inferenceModelId, this.enableContentGeneration, this.queryTool);
        }

        @Generated
        public String toString() {
            return "RAGTool.RAGToolBuilder(client=" + String.valueOf(this.client) + ", xContentRegistry=" + String.valueOf(this.xContentRegistry) + ", inferenceModelId=" + this.inferenceModelId + ", enableContentGeneration=" + this.enableContentGeneration + ", queryTool=" + String.valueOf(this.queryTool) + ")";
        }
    }

    /** Lazily created singleton factory that builds {@link RAGTool} instances from tool parameters. */
    public static class Factory implements WithModelTool.Factory<RAGTool> {
        private Client client;
        private NamedXContentRegistry xContentRegistry;
        // volatile: required for safe publication with the double-checked locking in getInstance().
        private static volatile Factory INSTANCE;

        /** Returns the shared factory, creating it on first use. */
        public static Factory getInstance() {
            Factory instance = INSTANCE;
            if (instance != null) {
                return instance;
            }
            synchronized (RAGTool.class) {
                if (INSTANCE == null) {
                    INSTANCE = new Factory();
                }
                return INSTANCE;
            }
        }

        public void init(Client client, NamedXContentRegistry xContentRegistry) {
            this.client = client;
            this.xContentRegistry = xContentRegistry;
        }

        /**
         * Builds a RAG tool backed by a neural-sparse or dense-vector retriever.
         *
         * <p>The caller's {@code params} map is not modified; the retriever receives a copy with
         * {@code model_id} set to {@code embedding_model_id}.
         *
         * @param params tool parameters; {@code query_type} defaults to {@code "neural"};
         *               {@code enable_content_generation} (String or Boolean) defaults to true
         * @throws IllegalArgumentException if {@code query_type} is neither {@code neural_sparse} nor {@code neural}
         */
        public RAGTool create(Map<String, Object> params) {
            // String.valueOf keeps an explicit null value as "null" (rejected below) rather than an NPE in switch.
            String queryType = params.containsKey(RAGTool.QUERY_TYPE) ? String.valueOf(params.get(RAGTool.QUERY_TYPE)) : "neural";
            String embeddingModelId = (String) params.get(RAGTool.EMBEDDING_MODEL_ID_FIELD);
            // Accepts both "true"/"false" strings and Boolean values; an explicit null parses as false.
            boolean enableContentGeneration = !params.containsKey(RAGTool.CONTENT_GENERATION_FIELD)
                || Boolean.parseBoolean(String.valueOf(params.get(RAGTool.CONTENT_GENERATION_FIELD)));
            String inferenceModelId = enableContentGeneration ? (String) params.get(RAGTool.INFERENCE_MODEL_ID_FIELD) : "";

            Map<String, Object> queryToolParams = new HashMap<>(params);
            queryToolParams.put("model_id", embeddingModelId);
            AbstractRetrieverTool queryTool;
            switch (queryType) {
                case "neural_sparse":
                    queryTool = (AbstractRetrieverTool) NeuralSparseSearchTool.Factory.getInstance().create(queryToolParams);
                    break;
                case "neural":
                    queryTool = (AbstractRetrieverTool) VectorDBTool.Factory.getInstance().create(queryToolParams);
                    break;
                default:
                    log.error("Failed to read queryType, please input neural_sparse or neural.");
                    throw new IllegalArgumentException("Failed to read queryType, please input neural_sparse or neural.");
            }
            return RAGTool.builder()
                .client(this.client)
                .xContentRegistry(this.xContentRegistry)
                .inferenceModelId(inferenceModelId)
                .enableContentGeneration(enableContentGeneration)
                .queryTool(queryTool)
                .build();
        }

        public String getDefaultDescription() {
            return DEFAULT_DESCRIPTION;
        }

        public String getDefaultType() {
            return RAGTool.TYPE;
        }

        public String getDefaultVersion() {
            return null;
        }

        public List<String> getAllModelKeys() {
            return List.of(RAGTool.INFERENCE_MODEL_ID_FIELD, RAGTool.EMBEDDING_MODEL_ID_FIELD);
        }
    }
}

