/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.agent;

import java.io.IOException;
import java.security.AccessController;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import lombok.Generated;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.StepListener;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.memory.Memory;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.algorithms.agent.AgentUtils;
import org.opensearch.ml.engine.algorithms.agent.MLAgentRunner;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
import org.opensearch.transport.client.Client;

/**
 * Agent runner for "flow" agents: executes the agent's configured tools one after another,
 * feeding each tool's output into the parameters resolved for the following tools.
 *
 * <p>When a step finishes, its output is converted to a string with {@link #parseResponse(Object)},
 * JSON-escaped and stored in the caller-supplied {@code params} map under
 * {@code "<toolName>.output"}. The execute parameters of later tools are built from that same map,
 * so later tools can see earlier outputs. The outputs of tools flagged with
 * {@code includeOutputInAgentResponse}, and of the last tool, are collected into a
 * {@code List<ModelTensor>} that is returned to the caller. If the request carries
 * {@code memory_id} and {@code parent_interaction_id} and the agent has a typed memory spec, those
 * outputs are also written to the interaction's {@code additional_info} before responding.
 *
 * <p>A flow with exactly one tool is special-cased: that tool's raw output is handed straight to
 * the caller's listener, without wrapping it in a list and without a memory update.
 */
public class MLFlowAgentRunner
implements MLAgentRunner {
    @Generated
    private static final Logger log = LogManager.getLogger(MLFlowAgentRunner.class);
    private Client client;
    private Settings settings;
    private ClusterService clusterService;
    private NamedXContentRegistry xContentRegistry;
    private Map<String, Tool.Factory> toolFactories;
    private Map<String, Memory.Factory> memoryFactoryMap;

    public MLFlowAgentRunner(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map<String, Tool.Factory> toolFactories, Map<String, Memory.Factory> memoryFactoryMap) {
        this.client = client;
        this.settings = settings;
        this.clusterService = clusterService;
        this.xContentRegistry = xContentRegistry;
        this.toolFactories = toolFactories;
        this.memoryFactoryMap = memoryFactoryMap;
    }

    /**
     * Runs the agent's tools sequentially and reports the outcome to {@code listener}.
     *
     * @param mlAgent  agent definition; supplies the tool specs (through
     *                 {@link AgentUtils#getMlToolSpecs}), the memory spec and the tenant id
     * @param params   request parameters. This map is MUTATED: every finished step adds its escaped
     *                 output under {@code "<toolName>.output"}.
     * @param listener on success receives the collected {@code List<ModelTensor>}, or the tool's raw
     *                 output for a single-tool flow; on failure receives the first step error
     * @throws IllegalArgumentException if the first tool's type has no registered factory
     */
    @Override
    public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener) {
        List<MLToolSpec> toolSpecs = AgentUtils.getMlToolSpecs(mlAgent, params);
        if (toolSpecs == null || toolSpecs.isEmpty()) {
            listener.onFailure(new IllegalArgumentException("no tool configured"));
            return;
        }
        String tenantId = mlAgent.getTenantId();
        MLMemorySpec memorySpec = mlAgent.getMemory();
        String memoryId = params.get("memory_id");
        String parentInteractionId = params.get("parent_interaction_id");

        MLToolSpec firstToolSpec = toolSpecs.get(0);
        Tool firstTool = this.createTool(firstToolSpec, tenantId);
        // Resolved eagerly: the first tool only sees the original request parameters.
        Map<String, String> firstToolExecuteParams = this.getToolExecuteParams(firstToolSpec, params, tenantId);
        if (toolSpecs.size() == 1) {
            firstTool.run(firstToolExecuteParams, listener);
            return;
        }

        List<ModelTensor> flowAgentOutput = new ArrayList<>();
        Map<String, Object> additionalInfo = new ConcurrentHashMap<>();
        StepListener<Object> firstStepListener = new StepListener<>();
        StepListener<Object> previousStepListener = firstStepListener;
        // Handler i consumes the output of tool i-1. It either starts tool i or, after the last
        // tool, completes the flow. Every step listener gets exactly one handler.
        for (int i = 1; i <= toolSpecs.size(); i++) {
            MLToolSpec previousToolSpec = toolSpecs.get(i - 1);
            boolean lastStep = i == toolSpecs.size();
            MLToolSpec nextToolSpec = lastStep ? null : toolSpecs.get(i);
            StepListener<Object> nextStepListener = lastStep ? null : new StepListener<>();
            previousStepListener.whenComplete(output -> {
                String toolName = AgentUtils.getToolName(previousToolSpec);
                String outputKey = toolName + ".output";
                String outputResponse = this.parseResponse(output);
                params.put(outputKey, StringEscapeUtils.escapeJson(outputResponse));
                if (previousToolSpec.isIncludeOutputInAgentResponse() || lastStep) {
                    appendStepOutput(flowAgentOutput, toolName, output);
                    additionalInfo.put(outputKey, outputResponse);
                }
                if (lastStep) {
                    this.respondWithFlowOutput(flowAgentOutput, additionalInfo, memorySpec, memoryId, parentInteractionId, listener);
                    return;
                }
                // Resolved lazily so the next tool sees the outputs stored in params above.
                Tool nextTool = this.createTool(nextToolSpec, tenantId);
                nextTool.run(this.getToolExecuteParams(nextToolSpec, params, tenantId), nextStepListener);
            }, e -> {
                log.error("Failed to run flow agent", e);
                listener.onFailure(e);
            });
            if (nextStepListener != null) {
                previousStepListener = nextStepListener;
            }
        }
        firstTool.run(firstToolExecuteParams, firstStepListener);
    }

    /**
     * Appends one step's output to the agent response list.
     *
     * <p>A {@link ModelTensorOutput} contributes the tensors of its first {@code ModelTensors}
     * entry. Any other output becomes a single {@link ModelTensor} named after the tool: strings are
     * used as-is and everything else is serialized with {@code StringUtils.toJson}.
     */
    private static void appendStepOutput(List<ModelTensor> flowAgentOutput, String toolName, Object output) throws java.security.PrivilegedActionException {
        if (output instanceof ModelTensorOutput) {
            ModelTensorOutput tensorOutput = (ModelTensorOutput) output;
            flowAgentOutput.addAll(tensorOutput.getMlModelOutputs().get(0).getMlModelTensors());
            return;
        }
        // Explicit functional-interface cast: a bare lambda is ambiguous between the
        // PrivilegedAction and PrivilegedExceptionAction overloads of doPrivileged.
        String result = output instanceof String
            ? (String) output
            : AccessController.doPrivileged((java.security.PrivilegedExceptionAction<String>) () -> StringUtils.toJson(output));
        flowAgentOutput.add(ModelTensor.builder().name(toolName).result(result).build());
    }

    /**
     * Completes the flow. If memory information is available, first records
     * {@code additionalInfo} on the parent interaction. The caller always receives
     * {@code flowAgentOutput`, whether or not the memory update succeeds.
     */
    private void respondWithFlowOutput(List<ModelTensor> flowAgentOutput, Map<String, Object> additionalInfo, MLMemorySpec memorySpec, String memoryId, String parentInteractionId, ActionListener<Object> listener) {
        if (memoryId == null || parentInteractionId == null || memorySpec == null || memorySpec.getType() == null) {
            listener.onResponse(flowAgentOutput);
            return;
        }
        ActionListener<UpdateResponse> updateListener = ActionListener.wrap(updateResponse -> {
            log.info("Updated additional info for interaction ID: {} in the flow agent.", updateResponse.getId());
            listener.onResponse(flowAgentOutput);
        }, e -> {
            log.error("Failed to update root interaction", e);
            listener.onResponse(flowAgentOutput);
        });
        this.updateMemoryWithListener(additionalInfo, memorySpec, memoryId, parentInteractionId, updateListener);
    }

    /**
     * Best-effort write of {@code additionalInfo} to an interaction. Failures are only logged.
     * Does nothing when any of the memory identifiers or the memory type is missing.
     */
    @VisibleForTesting
    void updateMemory(Map<String, Object> additionalInfo, MLMemorySpec memorySpec, String memoryId, String interactionId) {
        if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) {
            return;
        }
        ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) this.memoryFactoryMap.get(memorySpec.getType());
        if (conversationIndexMemoryFactory == null) {
            log.error("No memory factory registered for type: {}", memorySpec.getType());
            return;
        }
        conversationIndexMemoryFactory.create(memoryId, ActionListener.<ConversationIndexMemory>wrap(
            memory -> this.updateInteraction(additionalInfo, interactionId, memory),
            e -> log.error("Failed create memory from id: {}", memoryId, e)));
    }

    /**
     * Writes {@code additionalInfo} to an interaction and reports the result to {@code listener}
     * (expected to accept an {@link UpdateResponse}).
     *
     * <p>Every failure is routed to {@code listener.onFailure`, so the caller is never left waiting.
     * This covers a missing memory factory and failures to load the memory. When any memory
     * identifier or the memory type is missing, the method returns without touching
     * {@code listener`; callers must check those preconditions themselves.
     */
    @VisibleForTesting
    void updateMemoryWithListener(Map<String, Object> additionalInfo, MLMemorySpec memorySpec, String memoryId, String interactionId, ActionListener listener) {
        if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) {
            return;
        }
        ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) this.memoryFactoryMap.get(memorySpec.getType());
        if (conversationIndexMemoryFactory == null) {
            listener.onFailure(new IllegalArgumentException("No memory factory registered for type: " + memorySpec.getType()));
            return;
        }
        conversationIndexMemoryFactory.create(memoryId, ActionListener.<ConversationIndexMemory>wrap(
            memory -> this.updateInteractionWithListener(additionalInfo, interactionId, memory, listener),
            e -> {
                log.error("Failed create memory from id: {}", memoryId, e);
                listener.onFailure(e);
            }));
    }

    /** Stores {@code additionalInfo} under {@code "additional_info"} on the interaction; the outcome is only logged. */
    @VisibleForTesting
    void updateInteraction(Map<String, Object> additionalInfo, String interactionId, ConversationIndexMemory memory) {
        memory.getMemoryManager().updateInteraction(
            interactionId,
            ImmutableMap.<String, Object>of("additional_info", additionalInfo),
            ActionListener.<UpdateResponse>wrap(
                updateResponse -> log.info("Updated additional info for interaction ID: {}", interactionId),
                e -> log.error("Failed to update root interaction", e)));
    }

    /** Stores {@code additionalInfo} under {@code "additional_info"} on the interaction and forwards the result to {@code listener}. */
    @VisibleForTesting
    @SuppressWarnings("unchecked")
    void updateInteractionWithListener(Map<String, Object> additionalInfo, String interactionId, ConversationIndexMemory memory, ActionListener listener) {
        memory.getMemoryManager().updateInteraction(interactionId, ImmutableMap.<String, Object>of("additional_info", additionalInfo), (ActionListener<UpdateResponse>) listener);
    }

    /**
     * Converts a tool output to a string.
     *
     * <ul>
     *   <li>A non-empty list whose first element is {@link ModelTensors}: only that first element is
     *       serialized to JSON.</li>
     *   <li>{@link ModelTensor} or {@link ModelTensorOutput}: serialized to JSON via XContent.</li>
     *   <li>Strings: returned unchanged.</li>
     *   <li>Anything else: serialized with {@code StringUtils.toJson}.</li>
     * </ul>
     */
    @VisibleForTesting
    String parseResponse(Object output) throws IOException {
        if (output instanceof List && !((List) output).isEmpty() && ((List) output).get(0) instanceof ModelTensors) {
            ModelTensors tensors = (ModelTensors) ((List) output).get(0);
            return tensors.toXContent(JsonXContent.contentBuilder(), null).toString();
        }
        if (output instanceof ModelTensor) {
            return ((ModelTensor) output).toXContent(JsonXContent.contentBuilder(), null).toString();
        }
        if (output instanceof ModelTensorOutput) {
            return ((ModelTensorOutput) output).toXContent(JsonXContent.contentBuilder(), null).toString();
        }
        if (output instanceof String) {
            return (String) output;
        }
        return StringUtils.toJson(output);
    }

    /**
     * Instantiates the tool described by {@code toolSpec}, passing its static parameters plus
     * {@code tenant_id}, and applies the spec's name and description overrides.
     *
     * @throws IllegalArgumentException if no factory is registered for the tool type
     */
    @VisibleForTesting
    Tool createTool(MLToolSpec toolSpec, String tenantId) {
        Map<String, String> toolParams = new HashMap<>();
        if (toolSpec.getParameters() != null) {
            toolParams.putAll(toolSpec.getParameters());
        }
        toolParams.put("tenant_id", tenantId);
        Tool.Factory toolFactory = this.toolFactories.get(toolSpec.getType());
        if (toolFactory == null) {
            throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
        }
        Tool tool = toolFactory.create(toolParams);
        if (toolSpec.getName() != null) {
            tool.setName(toolSpec.getName());
        }
        if (toolSpec.getDescription() != null) {
            tool.setDescription(toolSpec.getDescription());
        }
        return tool;
    }

    /**
     * Builds the parameters used to execute one tool. Sources are applied in increasing
     * precedence:
     * <ol>
     *   <li>the tool spec's static parameters;</li>
     *   <li>request parameters without a tool prefix;</li>
     *   <li>request parameters prefixed with {@code "<toolType>."} (the prefix is stripped);</li>
     *   <li>request parameters prefixed with {@code "<toolName>."} (the prefix is stripped);</li>
     *   <li>the tool spec's config map;</li>
     *   <li>{@code tenant_id}.</li>
     * </ol>
     * Finally, {@code ${parameters.X}} placeholders in {@code input} are substituted from the
     * resulting map.
     */
    @VisibleForTesting
    Map<String, String> getToolExecuteParams(MLToolSpec toolSpec, Map<String, String> params, String tenantId) {
        Map<String, String> executeParams = new HashMap<>();
        if (toolSpec.getParameters() != null) {
            executeParams.putAll(toolSpec.getParameters());
        }
        String typePrefix = toolSpec.getType() + ".";
        String namePrefix = toolSpec.getName() != null ? toolSpec.getName() + "." : null;
        Map<String, String> typeScopedParams = new HashMap<>();
        Map<String, String> nameScopedParams = new HashMap<>();
        for (Map.Entry<String, String> entry : params.entrySet()) {
            String key = entry.getKey();
            // Strip only the leading prefix: String.replace would also delete later occurrences
            // (e.g. "a.data.output" -> "datoutput" for a tool named "a").
            if (namePrefix != null && key.startsWith(namePrefix)) {
                nameScopedParams.put(key.substring(namePrefix.length()), entry.getValue());
            } else if (key.startsWith(typePrefix)) {
                typeScopedParams.put(key.substring(typePrefix.length()), entry.getValue());
            } else {
                executeParams.put(key, entry.getValue());
            }
        }
        // Apply scoped values after generic ones so precedence does not depend on map iteration order.
        executeParams.putAll(typeScopedParams);
        executeParams.putAll(nameScopedParams);
        if (toolSpec.getConfigMap() != null && !toolSpec.getConfigMap().isEmpty()) {
            executeParams.putAll(toolSpec.getConfigMap());
        }
        executeParams.put("tenant_id", tenantId);
        if (executeParams.containsKey("input")) {
            StringSubstitutor substitutor = new StringSubstitutor(executeParams, "${parameters.", "}");
            executeParams.put("input", substitutor.replace(executeParams.get("input")));
        }
        return executeParams;
    }

    @Generated
    public Client getClient() {
        return this.client;
    }

    @Generated
    public Settings getSettings() {
        return this.settings;
    }

    @Generated
    public ClusterService getClusterService() {
        return this.clusterService;
    }

    @Generated
    public NamedXContentRegistry getXContentRegistry() {
        return this.xContentRegistry;
    }

    @Generated
    public Map<String, Tool.Factory> getToolFactories() {
        return this.toolFactories;
    }

    @Generated
    public Map<String, Memory.Factory> getMemoryFactoryMap() {
        return this.memoryFactoryMap;
    }

    @Generated
    public void setClient(Client client) {
        this.client = client;
    }

    @Generated
    public void setSettings(Settings settings) {
        this.settings = settings;
    }

    @Generated
    public void setClusterService(ClusterService clusterService) {
        this.clusterService = clusterService;
    }

    @Generated
    public void setXContentRegistry(NamedXContentRegistry xContentRegistry) {
        this.xContentRegistry = xContentRegistry;
    }

    @Generated
    public void setToolFactories(Map<String, Tool.Factory> toolFactories) {
        this.toolFactories = toolFactories;
    }

    @Generated
    public void setMemoryFactoryMap(Map<String, Memory.Factory> memoryFactoryMap) {
        this.memoryFactoryMap = memoryFactoryMap;
    }

    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MLFlowAgentRunner)) {
            return false;
        }
        MLFlowAgentRunner other = (MLFlowAgentRunner) o;
        if (!other.canEqual(this)) {
            return false;
        }
        Client this$client = this.getClient();
        Client other$client = other.getClient();
        if (this$client == null ? other$client != null : !this$client.equals(other$client)) {
            return false;
        }
        Settings this$settings = this.getSettings();
        Settings other$settings = other.getSettings();
        if (this$settings == null ? other$settings != null : !this$settings.equals(other$settings)) {
            return false;
        }
        ClusterService this$clusterService = this.getClusterService();
        ClusterService other$clusterService = other.getClusterService();
        if (this$clusterService == null ? other$clusterService != null : !this$clusterService.equals(other$clusterService)) {
            return false;
        }
        NamedXContentRegistry this$xContentRegistry = this.getXContentRegistry();
        NamedXContentRegistry other$xContentRegistry = other.getXContentRegistry();
        if (this$xContentRegistry == null ? other$xContentRegistry != null : !this$xContentRegistry.equals(other$xContentRegistry)) {
            return false;
        }
        Map<String, Tool.Factory> this$toolFactories = this.getToolFactories();
        Map<String, Tool.Factory> other$toolFactories = other.getToolFactories();
        if (this$toolFactories == null ? other$toolFactories != null : !this$toolFactories.equals(other$toolFactories)) {
            return false;
        }
        Map<String, Memory.Factory> this$memoryFactoryMap = this.getMemoryFactoryMap();
        Map<String, Memory.Factory> other$memoryFactoryMap = other.getMemoryFactoryMap();
        return this$memoryFactoryMap == null ? other$memoryFactoryMap == null : this$memoryFactoryMap.equals(other$memoryFactoryMap);
    }

    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof MLFlowAgentRunner;
    }

    @Generated
    public int hashCode() {
        int result = 1;
        Client $client = this.getClient();
        result = result * 59 + ($client == null ? 43 : $client.hashCode());
        Settings $settings = this.getSettings();
        result = result * 59 + ($settings == null ? 43 : $settings.hashCode());
        ClusterService $clusterService = this.getClusterService();
        result = result * 59 + ($clusterService == null ? 43 : $clusterService.hashCode());
        NamedXContentRegistry $xContentRegistry = this.getXContentRegistry();
        result = result * 59 + ($xContentRegistry == null ? 43 : $xContentRegistry.hashCode());
        Map<String, Tool.Factory> $toolFactories = this.getToolFactories();
        result = result * 59 + ($toolFactories == null ? 43 : $toolFactories.hashCode());
        Map<String, Memory.Factory> $memoryFactoryMap = this.getMemoryFactoryMap();
        result = result * 59 + ($memoryFactoryMap == null ? 43 : $memoryFactoryMap.hashCode());
        return result;
    }

    @Generated
    public String toString() {
        return "MLFlowAgentRunner(client=" + String.valueOf(this.getClient()) + ", settings=" + String.valueOf(this.getSettings()) + ", clusterService=" + String.valueOf(this.getClusterService()) + ", xContentRegistry=" + String.valueOf(this.getXContentRegistry()) + ", toolFactories=" + String.valueOf(this.getToolFactories()) + ", memoryFactoryMap=" + String.valueOf(this.getMemoryFactoryMap()) + ")";
    }

    @Generated
    public MLFlowAgentRunner() {
    }
}

