/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.agent;

import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Predicate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import joptsimple.internal.Strings;
import lombok.Generated;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.StepListener;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.memory.Memory;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.algorithms.agent.AgentUtils;
import org.opensearch.ml.engine.algorithms.agent.MLAgentRunner;
import org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.transport.client.Client;

/**
 * Agent runner that drives a plan / execute / reflect loop.
 *
 * <p>The runner asks the planner LLM for a JSON plan ({@code steps} and {@code result}), hands the
 * first step to the agent identified by the {@code reAct_agent_id} parameter, records the step and
 * its outcome, and then asks the LLM to re-evaluate the plan. The loop ends when the LLM returns a
 * non-empty {@code result} or when {@code max_steps} iterations have been executed.
 */
public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner {
    @Generated
    private static final Logger log = LogManager.getLogger(MLPlanExecuteAndReflectAgentRunner.class);

    private final Client client;
    private final Settings settings;
    private final ClusterService clusterService;
    private final NamedXContentRegistry xContentRegistry;
    private final Map<String, Tool.Factory> toolFactories;
    private final Map<String, Memory.Factory> memoryFactoryMap;

    private static final String DEFAULT_DEEP_RESEARCH_SYSTEM_PROMPT = "Always respond in JSON format.";
    private static final String DEFAULT_REACT_SYSTEM_PROMPT = "You are a helpful assistant.";
    private static final String DEFAULT_NO_ESCAPE_PARAMS = "tool_configs,_tools";
    private static final String DEFAULT_MAX_STEPS_EXECUTED = "20";
    private static final int DEFAULT_MESSAGE_HISTORY_LIMIT = 10;

    // Parameter keys that are read in several places of this class.
    private static final String LLM_INTERFACE_FIELD = "_llm_interface";
    private static final String LLM_RESPONSE_FILTER_FIELD = "llm_response_filter";
    private static final String PARENT_INTERACTION_ID_FIELD = "parent_interaction_id";

    // Separator used when a list of steps is rendered into a prompt parameter.
    private static final String STEP_SEPARATOR = ", ";

    private static final String DEFAULT_PLANNER_PROMPT =
        "For the given objective, come up with a simple step by step plan. This plan should involve individual tasks, that if executed correctly will yield the correct answer. Do not add any superfluous steps. The result of the final step should be the final answer. Make sure that each step has all the information needed - do not skip steps. At all costs, do not execute the steps. You will be told when to execute the steps.";

    private static final String DEFAULT_REVAL_PROMPT =
        "Update your plan accordingly. If no more steps are needed and you can return to the user, then respond with that. Otherwise, fill out the plan. Only add steps to the plan that still NEED to be done. Do not return previously done steps as part of the plan. Please follow the below response format.";

    // Split at "\n" boundaries purely for readability; the concatenation is a compile-time constant.
    private static final String DEFAULT_RESPONSE_FORMAT = "${parameters.tools_prompt:-} \n"
        + "Response Instructions: \n"
        + "ALWAYS follow the given response instructions. Do not return any content that does not follow the response instructions. Do not add anything before or after the expected JSON \n"
        + "Always respond with a valid JSON object that strictly follows the below schema:\n"
        + "{\n"
        + "\t\"steps\": array[string], \n"
        + "\t\"result\": string \n"
        + "}\n"
        + "Use \"steps\" to return an array of strings where each string is a step to complete the objective, leave it empty if you know the final result. Please wrap each step in quotes and escape any special characters within the string. \n"
        + "Use \"result\" return the final response when you have enough information, leave it empty if you want to execute more steps \n"
        + "Here are examples of valid responses:\n"
        + "\n"
        + "Example 1 - When you need to execute steps:\n"
        + "{\n"
        + "\t\"steps\": [\"Search for logs containing error messages in the last hour\", \"Analyze the frequency of each error type\", \"Check system metrics during error spikes\"],\n"
        + "\t\"result\": \"\"\n"
        + "}\n"
        + "\n"
        + "Example 2 - When you have the final result:\n"
        + "{\n"
        + "\t\"steps\": [],\n"
        + "\t\"result\": \"Based on the analysis, the root cause of the system slowdown was a memory leak in the authentication service, which started at 14:30 UTC.\"\n"
        + "}\n"
        + "IMPORTANT RULES:\n"
        + "1. DO NOT use commas within individual steps \n"
        + "2. DO NOT add any content before or after the JSON \n"
        + "3. ONLY respond with a pure JSON object \n"
        + "4. DO NOT USE ANY TOOLS. TOOLS ARE PROVIDED ONLY FOR YOU TO MAKE A PLAN.";

    private static final String PLANNER_PROMPT_TEMPLATE =
        "${parameters.planner_prompt} \nObjective: ${parameters.user_prompt} \n\n${parameters.deep_research_response_format}";

    private static final String REVAL_PROMPT_TEMPLATE =
        "${parameters.planner_prompt} \n\nObjective: ${parameters.user_prompt} \n\nOriginal plan:\n[${parameters.steps}] \n\nYou have currently executed the following steps: \n[${parameters.completed_steps}] \n\n${parameters.reval_prompt} \n\n${parameters.deep_research_response_format}";

    private static final String PLANNER_WITH_HISTORY_PROMPT_TEMPLATE =
        "${parameters.planner_prompt} \nObjective: ${parameters.user_prompt} \n\nYou have currently executed the following steps: \n[${parameters.completed_steps}] \n\n${parameters.deep_research_response_format}";

    public static final String PROMPT_FIELD = "prompt";
    public static final String USER_PROMPT_FIELD = "user_prompt";
    public static final String REACT_SYSTEM_PROMPT_FIELD = "react_system_prompt";
    public static final String STEPS_FIELD = "steps";
    public static final String COMPLETED_STEPS_FIELD = "completed_steps";
    public static final String PLANNER_PROMPT_FIELD = "planner_prompt";
    public static final String REVAL_PROMPT_FIELD = "reval_prompt";
    public static final String DEEP_RESEARCH_RESPONSE_FORMAT_FIELD = "deep_research_response_format";
    public static final String PROMPT_TEMPLATE_FIELD = "prompt_template";
    public static final String SYSTEM_PROMPT_FIELD = "system_prompt";
    public static final String QUESTION_FIELD = "question";
    public static final String MEMORY_ID_FIELD = "memory_id";
    public static final String TENANT_ID_FIELD = "tenant_id";
    public static final String RESULT_FIELD = "result";
    public static final String RESPONSE_FIELD = "response";
    public static final String STEP_RESULT_FIELD = "step_result";
    public static final String REACT_AGENT_ID_FIELD = "reAct_agent_id";
    public static final String REACT_AGENT_MEMORY_ID_FIELD = "reAct_agent_memory_id";
    public static final String NO_ESCAPE_PARAMS_FIELD = "no_escape_params";
    public static final String DEFAULT_PROMPT_TOOLS_FIELD = "tools_prompt";
    public static final String MAX_STEPS_EXECUTED_FIELD = "max_steps";

    /**
     * Planner answer after parsing.
     *
     * @param steps  remaining plan steps, never {@code null}, possibly empty
     * @param result final answer, or {@code null} when the planner did not provide a non-empty one
     */
    private record PlannerOutput(List<String> steps, String result) {
    }

    /**
     * Creates a runner; all arguments are stored as-is.
     *
     * @param client           client used to call the planner model and the delegate agent
     * @param settings         node settings (stored, not read by this class)
     * @param clusterService   cluster service (stored, not read by this class)
     * @param registry         x-content registry (stored, not read by this class)
     * @param toolFactories    tool factories passed to {@code AgentUtils.createTools}
     * @param memoryFactoryMap memory factories keyed by memory type
     */
    public MLPlanExecuteAndReflectAgentRunner(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry registry, Map<String, Tool.Factory> toolFactories, Map<String, Memory.Factory> memoryFactoryMap) {
        this.client = client;
        this.settings = settings;
        this.clusterService = clusterService;
        this.xContentRegistry = registry;
        this.toolFactories = toolFactories;
        this.memoryFactoryMap = memoryFactoryMap;
    }

    /**
     * Seeds the parameter map with the default planner prompts and, when an LLM interface is
     * declared without an explicit response filter, the JSONPath filter for that interface.
     *
     * @param params mutable parameter map, updated in place
     * @throws MLException if {@code _llm_interface} names an interface this class does not know
     */
    private void setupPromptParameters(Map<String, String> params) {
        // A caller-supplied prompt is discarded; the prompt is always rebuilt from a template.
        params.remove(PROMPT_FIELD);

        params.put(USER_PROMPT_FIELD, params.get(QUESTION_FIELD));

        // The JSON instruction is appended directly after any user-provided system prompt.
        String userSystemPrompt = params.getOrDefault(SYSTEM_PROMPT_FIELD, "");
        params.put(SYSTEM_PROMPT_FIELD, userSystemPrompt + DEFAULT_DEEP_RESEARCH_SYSTEM_PROMPT);

        params.put(PLANNER_PROMPT_FIELD, DEFAULT_PLANNER_PROMPT);
        params.put(REVAL_PROMPT_FIELD, DEFAULT_REVAL_PROMPT);
        params.put(DEEP_RESEARCH_RESPONSE_FORMAT_FIELD, DEFAULT_RESPONSE_FORMAT);
        params.put(NO_ESCAPE_PARAMS_FIELD, DEFAULT_NO_ESCAPE_PARAMS);

        String llmInterface = params.get(LLM_INTERFACE_FIELD);
        if (llmInterface != null && isNullOrEmpty(params.get(LLM_RESPONSE_FILTER_FIELD))) {
            String llmResponseFilter = switch (llmInterface.trim().toLowerCase(Locale.ROOT)) {
                case "bedrock/converse/claude", "bedrock/converse/deepseek_r1" -> "$.output.message.content[0].text";
                case "openai/v1/chat/completions" -> "$.choices[0].message.content";
                default -> throw new MLException(String.format("Unsupported llm interface: %s", llmInterface));
            };
            params.put(LLM_RESPONSE_FILTER_FIELD, llmResponseFilter);
        }
    }

    /** Selects the initial planning template and renders the prompt. */
    private void usePlannerPromptTemplate(Map<String, String> params) {
        params.put(PROMPT_TEMPLATE_FIELD, PLANNER_PROMPT_TEMPLATE);
        populatePrompt(params);
    }

    /** Selects the re-evaluation template (original plan plus completed steps) and renders the prompt. */
    private void useRevalPromptTemplate(Map<String, String> params) {
        params.put(PROMPT_TEMPLATE_FIELD, REVAL_PROMPT_TEMPLATE);
        populatePrompt(params);
    }

    /** Selects the planning template that also lists previously completed steps and renders the prompt. */
    private void usePlannerWithHistoryPromptTemplate(Map<String, String> params) {
        params.put(PROMPT_TEMPLATE_FIELD, PLANNER_WITH_HISTORY_PROMPT_TEMPLATE);
        populatePrompt(params);
    }

    /**
     * Renders {@code prompt_template} by substituting {@code ${parameters.<key>}} placeholders with
     * values from the same map and stores the outcome under {@code prompt}.
     */
    private void populatePrompt(Map<String, String> allParams) {
        String promptTemplate = allParams.get(PROMPT_TEMPLATE_FIELD);
        StringSubstitutor promptSubstitutor = new StringSubstitutor(allParams, "${parameters.", "}");
        allParams.put(PROMPT_FIELD, promptSubstitutor.replace(promptTemplate));
    }

    /**
     * Starts the agent: merges parameters, prepares the planner prompt, loads the conversation
     * memory and its recent history, and enters the planning loop.
     *
     * @param mlAgent   agent definition; its parameters override same-named request parameters
     * @param apiParams request parameters
     * @param listener  receives the final {@code ModelTensorOutput} or the failure
     */
    @Override
    public void run(MLAgent mlAgent, Map<String, String> apiParams, ActionListener<Object> listener) {
        Map<String, String> allParams = new HashMap<>(apiParams);
        if (mlAgent.getParameters() != null) {
            allParams.putAll(mlAgent.getParameters());
        }

        setupPromptParameters(allParams);
        usePlannerPromptTemplate(allParams);

        String memoryId = allParams.get(MEMORY_ID_FIELD);
        String memoryType = mlAgent.getMemory().getType();
        String appType = mlAgent.getAppType();

        ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);

        // user_prompt is only ever written into allParams (see setupPromptParameters), so it has to
        // be read from there; reading it from apiParams would normally yield null.
        conversationIndexMemoryFactory.create(
            allParams.get(USER_PROMPT_FIELD),
            memoryId,
            appType,
            ActionListener.<ConversationIndexMemory>wrap(
                memory -> memory.getMessages(ActionListener.<List<Interaction>>wrap(interactions -> {
                    List<String> completedSteps = collectCompletedSteps(interactions);
                    if (!completedSteps.isEmpty()) {
                        addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD);
                        usePlannerWithHistoryPromptTemplate(allParams);
                    }
                    setToolsAndRunAgent(mlAgent, allParams, completedSteps, memory, memory.getConversationId(), listener);
                }, e -> {
                    log.error("Failed to get chat history", e);
                    listener.onFailure(e);
                }), DEFAULT_MESSAGE_HISTORY_LIMIT),
                listener::onFailure
            )
        );
    }

    /**
     * Flattens stored interactions into alternating input / response entries, skipping
     * interactions whose response is null or empty.
     */
    private static List<String> collectCompletedSteps(List<Interaction> interactions) {
        List<String> completedSteps = new ArrayList<>();
        for (Interaction interaction : interactions) {
            String response = interaction.getResponse();
            if (isNullOrEmpty(response)) {
                continue;
            }
            completedSteps.add(interaction.getInput());
            completedSteps.add(response);
        }
        return completedSteps;
    }

    /**
     * Creates the agent's tools, adds their descriptions to the prompt and starts the planning
     * loop with a fresh step counter and trace number.
     */
    private void setToolsAndRunAgent(MLAgent mlAgent, Map<String, String> allParams, List<String> completedSteps, Memory memory, String conversationId, ActionListener<Object> finalListener) {
        List<MLToolSpec> toolSpecs = AgentUtils.getMlToolSpecs(mlAgent, allParams);
        Map<String, Tool> tools = new HashMap<>();
        Map<String, MLToolSpec> toolSpecMap = new HashMap<>();
        AgentUtils.createTools(toolFactories, allParams, toolSpecs, tools, toolSpecMap, mlAgent);

        addToolsToPrompt(tools, allParams);

        AtomicInteger traceNumber = new AtomicInteger(0);
        executePlanningLoop(mlAgent.getLlm(), allParams, completedSteps, memory, conversationId, 0, traceNumber, finalListener);
    }

    /**
     * One iteration of the loop: ask the planner, then either return its final result or execute
     * the first planned step through the delegate agent and recurse with the re-evaluation prompt.
     *
     * @param llm            planner model specification
     * @param allParams      mutable parameter map shared across iterations
     * @param completedSteps mutable list of alternating step / step-result entries
     * @param memory         conversation memory (cast to {@link ConversationIndexMemory} where needed)
     * @param conversationId id of the conversation used for trace data
     * @param stepsExecuted  number of steps executed so far in this run
     * @param traceNumber    trace counter passed to {@code MLChatAgentRunner.saveTraceData}
     * @param finalListener  receives the final output or the failure
     */
    private void executePlanningLoop(LLMSpec llm, Map<String, String> allParams, List<String> completedSteps, Memory memory, String conversationId, int stepsExecuted, AtomicInteger traceNumber, ActionListener<Object> finalListener) {
        int maxSteps = Integer.parseInt(allParams.getOrDefault(MAX_STEPS_EXECUTED_FIELD, DEFAULT_MAX_STEPS_EXECUTED));
        String parentInteractionId = allParams.get(PARENT_INTERACTION_ID_FIELD);

        if (stepsExecuted >= maxSteps) {
            String lastStep = null;
            String finalResult;
            if (completedSteps.size() >= 2) {
                lastStep = completedSteps.get(completedSteps.size() - 2);
                finalResult = String.format("Max Steps Limit Reached. Use memory_id with same task to restart. \n Last executed step: %s, \n Last executed step result: %s", lastStep, completedSteps.getLast());
            } else {
                // Reachable when max_steps <= 0 and the conversation has no history yet.
                finalResult = "Max Steps Limit Reached. Use memory_id with same task to restart. \n No step was executed.";
            }
            saveAndReturnFinalResult((ConversationIndexMemory) memory, parentInteractionId, finalResult, lastStep, finalListener);
            return;
        }

        MLPredictionTaskRequest request = new MLPredictionTaskRequest(
            llm.getModelId(),
            RemoteInferenceMLInput.builder()
                .algorithm(FunctionName.REMOTE)
                .inputDataset(RemoteInferenceInputDataSet.builder().parameters(allParams).build())
                .build(),
            null,
            allParams.get(TENANT_ID_FIELD)
        );

        // The response type is inferred from the action type, so no raw listener is needed.
        // Exceptions thrown inside the response handler are routed to the failure handler by wrap().
        client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(llmOutput -> {
            ModelTensorOutput modelTensorOutput = (ModelTensorOutput) llmOutput.getOutput();
            PlannerOutput plan = parseLLMOutput(allParams, modelTensorOutput);

            if (plan.result() != null) {
                saveAndReturnFinalResult((ConversationIndexMemory) memory, parentInteractionId, plan.result(), null, finalListener);
                return;
            }
            if (plan.steps().isEmpty()) {
                throw new IllegalStateException("LLM response contains neither a final result nor any steps to execute");
            }

            addSteps(plan.steps(), allParams, STEPS_FIELD);
            // Taken from the parsed list rather than re-split from the joined string, so a step
            // that itself contains ", " is executed in full.
            String stepToExecute = plan.steps().getFirst();
            MLExecuteTaskRequest executeRequest = buildReActRequest(stepToExecute, allParams);

            client.execute(MLExecuteTaskAction.INSTANCE, executeRequest, ActionListener.wrap(executeResponse -> {
                Map<String, String> results = extractReActResults((ModelTensorOutput) executeResponse.getOutput());
                String stepResult = results.get(STEP_RESULT_FIELD);
                if (stepResult == null) {
                    throw new IllegalStateException("No valid response found in ReAct agent output");
                }

                // Remember the delegate's memory so that following steps continue the same conversation.
                String reActMemoryId = results.get(MEMORY_ID_FIELD);
                if (!isNullOrEmpty(reActMemoryId)) {
                    allParams.put(REACT_AGENT_MEMORY_ID_FIELD, reActMemoryId);
                }

                completedSteps.add(stepToExecute);
                completedSteps.add(stepResult);
                MLChatAgentRunner.saveTraceData((ConversationIndexMemory) memory, memory.getType(), stepToExecute, stepResult, conversationId, false, parentInteractionId, traceNumber, "PlanExecuteReflect Agent");

                addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD);
                useRevalPromptTemplate(allParams);
                executePlanningLoop(llm, allParams, completedSteps, memory, conversationId, stepsExecuted + 1, traceNumber, finalListener);
            }, e -> {
                log.error("Failed to execute ReAct agent", e);
                finalListener.onFailure(e);
            }));
        }, e -> {
            log.error("Failed to run deep research agent", e);
            finalListener.onFailure(e);
        }));
    }

    /**
     * Builds the execute request that hands a single step to the delegate agent named by
     * {@code reAct_agent_id}, forwarding its memory id (if known), system prompt and response filter.
     */
    private MLExecuteTaskRequest buildReActRequest(String stepToExecute, Map<String, String> allParams) {
        Map<String, String> reactParams = new HashMap<>();
        reactParams.put(QUESTION_FIELD, stepToExecute);
        if (allParams.containsKey(REACT_AGENT_MEMORY_ID_FIELD)) {
            reactParams.put(MEMORY_ID_FIELD, allParams.get(REACT_AGENT_MEMORY_ID_FIELD));
        }
        reactParams.put(SYSTEM_PROMPT_FIELD, allParams.getOrDefault(REACT_SYSTEM_PROMPT_FIELD, DEFAULT_REACT_SYSTEM_PROMPT));
        reactParams.put(LLM_RESPONSE_FILTER_FIELD, allParams.get(LLM_RESPONSE_FILTER_FIELD));

        AgentMLInput agentInput = AgentMLInput.AgentMLInputBuilder()
            .agentId(allParams.get(REACT_AGENT_ID_FIELD))
            .functionName(FunctionName.AGENT)
            .inputDataset(RemoteInferenceInputDataSet.builder().parameters(reactParams).build())
            .build();
        return new MLExecuteTaskRequest(FunctionName.AGENT, agentInput);
    }

    /**
     * Scans every tensor of the delegate agent's output and collects the memory id (tensor named
     * {@code memory_id}) and the step result (the {@code response} entry of a tensor's data map).
     * When several tensors qualify, the last one wins.
     */
    private static Map<String, String> extractReActResults(ModelTensorOutput reactResult) {
        Map<String, String> results = new HashMap<>();
        for (ModelTensors output : reactResult.getMlModelOutputs()) {
            for (ModelTensor tensor : output.getMlModelTensors()) {
                if (MEMORY_ID_FIELD.equals(tensor.getName())) {
                    results.put(MEMORY_ID_FIELD, tensor.getResult());
                    continue;
                }
                Map<String, ?> dataMap = tensor.getDataAsMap();
                if (dataMap != null && dataMap.containsKey(RESPONSE_FIELD)) {
                    results.put(STEP_RESULT_FIELD, (String) dataMap.get(RESPONSE_FIELD));
                }
            }
        }
        return results;
    }

    /**
     * Extracts the planner's JSON answer from the first tensor of the model output and parses it.
     *
     * <p>If the tensor's data map consists of a single {@code response} entry, that value is used;
     * otherwise the text is located with the JSONPath stored under {@code llm_response_filter}.
     * Text that is not JSON by itself is searched for a fenced {@code ```json} block.
     *
     * @throws IllegalArgumentException if no response filter is available when one is needed, if the
     *                                  JSON has neither {@code steps} nor {@code result}, or if
     *                                  either field has an unexpected type
     * @throws IllegalStateException    if the tensor carries no data map or the text is not valid JSON
     */
    private PlannerOutput parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) {
        Map<String, ?> dataAsMap = modelTensorOutput.getMlModelOutputs().getFirst().getMlModelTensors().getFirst().getDataAsMap();
        if (dataAsMap == null) {
            throw new IllegalStateException("LLM output does not contain any data");
        }

        String llmResponse;
        if (dataAsMap.size() == 1 && dataAsMap.containsKey(RESPONSE_FIELD)) {
            llmResponse = ((String) dataAsMap.get(RESPONSE_FIELD)).trim();
        } else {
            String responseFilter = allParams.get(LLM_RESPONSE_FILTER_FIELD);
            if (isNullOrEmpty(responseFilter)) {
                throw new IllegalArgumentException("llm_response_filter not found. Please provide the path to the model output.");
            }
            String filtered = JsonPath.read((Object) dataAsMap, responseFilter);
            llmResponse = filtered.trim();
        }

        String json = StringUtils.isJson(llmResponse) ? llmResponse : extractJsonFromMarkdown(llmResponse);
        Map<String, Object> parsedJson = StringUtils.fromJson(json, RESPONSE_FIELD);
        if (!parsedJson.containsKey(STEPS_FIELD) && !parsedJson.containsKey(RESULT_FIELD)) {
            throw new IllegalArgumentException("Missing required fields 'steps' and 'result' in JSON response");
        }

        return new PlannerOutput(toStepList(parsedJson.get(STEPS_FIELD)), toResult(parsedJson.get(RESULT_FIELD)));
    }

    /**
     * Converts the raw {@code steps} value into a list of non-blank step strings.
     * A missing or JSON-null value yields an empty list.
     */
    private static List<String> toStepList(Object rawSteps) {
        if (rawSteps == null) {
            return List.of();
        }
        if (!(rawSteps instanceof List<?> rawList)) {
            throw new IllegalArgumentException("Field 'steps' in JSON response must be an array of strings");
        }
        List<String> steps = new ArrayList<>(rawList.size());
        for (Object rawStep : rawList) {
            if (rawStep == null) {
                continue;
            }
            String step = rawStep.toString();
            if (!step.isBlank()) {
                steps.add(step);
            }
        }
        return steps;
    }

    /** Returns the raw {@code result} value as a string, or {@code null} when it is absent or empty. */
    private static String toResult(Object rawResult) {
        if (rawResult == null) {
            return null;
        }
        if (!(rawResult instanceof String result)) {
            throw new IllegalArgumentException("Field 'result' in JSON response must be a string");
        }
        return result.isEmpty() ? null : result;
    }

    /**
     * Returns the content of a {@code ```json} fenced block when one is present, otherwise the
     * trimmed input.
     *
     * @throws IllegalStateException if the extracted text is not valid JSON
     */
    private String extractJsonFromMarkdown(String response) {
        String candidate = response.trim();
        String fenceStart = "```json";
        if (candidate.contains(fenceStart)) {
            candidate = candidate.substring(candidate.indexOf(fenceStart) + fenceStart.length());
            if (candidate.contains("```")) {
                candidate = candidate.substring(0, candidate.lastIndexOf("```"));
            }
        }
        candidate = candidate.trim();
        if (!StringUtils.isJson(candidate)) {
            throw new IllegalStateException("Failed to parse LLM output due to invalid JSON");
        }
        return candidate;
    }

    /**
     * Stores a textual list of the available tools under {@code tools_prompt} and re-renders the
     * prompt so the current template picks it up.
     */
    private void addToolsToPrompt(Map<String, Tool> tools, Map<String, String> allParams) {
        StringBuilder toolsPrompt = new StringBuilder("In this environment, you have access to the below tools: \n");
        for (Map.Entry<String, Tool> entry : tools.entrySet()) {
            toolsPrompt.append("- ").append(entry.getKey()).append(": ").append(entry.getValue().getDescription()).append("\n").append("\n");
        }
        allParams.put(DEFAULT_PROMPT_TOOLS_FIELD, toolsPrompt.toString());
        populatePrompt(allParams);
    }

    /** Stores {@code steps}, joined with ", ", under {@code field} for use in prompt templates. */
    private void addSteps(List<String> steps, Map<String, String> allParams, String field) {
        allParams.put(field, String.join(STEP_SEPARATOR, steps));
    }

    /**
     * Writes the final result (and optionally the input) to the parent interaction and, once the
     * update succeeds, answers the listener with memory id, parent interaction id and the result.
     *
     * @param input value stored as the interaction's {@code input}; skipped when {@code null}
     */
    private void saveAndReturnFinalResult(ConversationIndexMemory memory, String parentInteractionId, String finalResult, String input, ActionListener<Object> finalListener) {
        Map<String, Object> updateContent = new HashMap<>();
        updateContent.put(RESPONSE_FIELD, finalResult);
        if (input != null) {
            updateContent.put("input", input);
        }

        memory.getMemoryManager().updateInteraction(parentInteractionId, updateContent, ActionListener.<UpdateResponse>wrap(res -> {
            List<ModelTensors> finalModelTensors = createModelTensors(memory.getConversationId(), parentInteractionId);
            finalModelTensors.add(
                ModelTensors.builder()
                    .mlModelTensors(List.of(ModelTensor.builder().name(RESPONSE_FIELD).dataAsMap(Map.of(RESPONSE_FIELD, finalResult)).build()))
                    .build()
            );
            finalListener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
        }, e -> {
            log.error("Failed to update interaction with final result", e);
            finalListener.onFailure(e);
        }));
    }

    /** Creates a mutable tensor list pre-filled with the memory id and parent interaction id tensors. */
    private static List<ModelTensors> createModelTensors(String sessionId, String parentInteractionId) {
        List<ModelTensors> modelTensors = new ArrayList<>();
        modelTensors.add(
            ModelTensors.builder()
                .mlModelTensors(
                    List.of(
                        ModelTensor.builder().name(MEMORY_ID_FIELD).result(sessionId).build(),
                        ModelTensor.builder().name(PARENT_INTERACTION_ID_FIELD).result(parentInteractionId).build()
                    )
                )
                .build()
        );
        return modelTensors;
    }

    /** Null-safe emptiness check used for optional parameters and model responses. */
    private static boolean isNullOrEmpty(String value) {
        return value == null || value.isEmpty();
    }
}

