/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.transport;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.Generated;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.nodes.BaseNodesResponse;
import org.opensearch.cluster.ClusterName;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.neuralsearch.stats.common.StatSnapshot;
import org.opensearch.neuralsearch.transport.NeuralStatsNodeResponse;

/**
 * Nodes-level response that carries neural-search stat snapshots and knows how to
 * render them as XContent.
 *
 * <p>Besides the per-node responses and failures held by the base class, it stores
 * three stat maps (info stats, stats aggregated over all nodes, and per-node event
 * stats keyed by node id) plus two flags that control rendering:
 * <ul>
 *   <li>{@code flatten} - keep the dotted stat names as flat keys instead of
 *       expanding them into nested objects;</li>
 *   <li>{@code includeMetadata} - emit the whole {@link StatSnapshot} instead of only
 *       its value.</li>
 * </ul>
 */
public class NeuralStatsResponse
extends BaseNodesResponse<NeuralStatsNodeResponse>
implements ToXContentObject {
    public static final String INFO_KEY_PREFIX = "info";
    public static final String NODES_KEY_PREFIX = "nodes";
    public static final String AGGREGATED_NODES_KEY_PREFIX = "all_nodes";
    private final Map<String, StatSnapshot<?>> infoStats;
    private final Map<String, StatSnapshot<?>> aggregatedNodeStats;
    private final Map<String, Map<String, StatSnapshot<?>>> nodeIdToNodeEventStats;
    private final boolean flatten;
    private final boolean includeMetadata;

    /**
     * Deserializes a response from a stream. The read order mirrors
     * {@link #writeTo(StreamOutput)}: base-class data, the three stat maps, then the
     * two flags.
     *
     * @param in stream to read from
     * @throws IOException if reading from the stream fails
     */
    public NeuralStatsResponse(StreamInput in) throws IOException {
        super(new ClusterName(in), in.readList(NeuralStatsNodeResponse::readStats), in.readList(FailedNodeException::new));
        // NOTE(review): the maps come back as generic Map<String, Object>; the element
        // types are trusted, not verified - confirm the stream can round-trip StatSnapshot.
        this.infoStats = NeuralStatsResponse.readTypedMap(in);
        this.aggregatedNodeStats = NeuralStatsResponse.readTypedMap(in);
        this.nodeIdToNodeEventStats = NeuralStatsResponse.readTypedMap(in);
        this.flatten = in.readBoolean();
        this.includeMetadata = in.readBoolean();
    }

    /**
     * Creates a response from already collected data.
     *
     * @param clusterName            name of the cluster, passed to the base class
     * @param nodes                  per-node responses, passed to the base class
     * @param failures               per-node failures, passed to the base class
     * @param infoStats              info stats keyed by (possibly dotted) stat name
     * @param aggregatedNodeStats    stats aggregated over all nodes, keyed by stat name
     * @param nodeIdToNodeEventStats per-node stats, keyed by node id and then stat name
     * @param flatten                whether to render dotted names as flat keys
     * @param includeMetadata        whether to render full snapshots instead of bare values
     */
    public NeuralStatsResponse(ClusterName clusterName, List<NeuralStatsNodeResponse> nodes, List<FailedNodeException> failures, Map<String, StatSnapshot<?>> infoStats, Map<String, StatSnapshot<?>> aggregatedNodeStats, Map<String, Map<String, StatSnapshot<?>>> nodeIdToNodeEventStats, boolean flatten, boolean includeMetadata) {
        super(clusterName, nodes, failures);
        this.infoStats = infoStats;
        this.aggregatedNodeStats = aggregatedNodeStats;
        this.nodeIdToNodeEventStats = nodeIdToNodeEventStats;
        this.flatten = flatten;
        this.includeMetadata = includeMetadata;
    }

    /**
     * Serializes this response: base-class data first, then the three stat maps and
     * the two flags, in the order the stream constructor reads them.
     *
     * @param out stream to write to
     * @throws IOException if writing to the stream fails
     */
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeMap(NeuralStatsResponse.asObjectMap(this.infoStats));
        out.writeMap(NeuralStatsResponse.asObjectMap(this.aggregatedNodeStats));
        out.writeMap(NeuralStatsResponse.asObjectMap(this.nodeIdToNodeEventStats));
        out.writeBoolean(this.flatten);
        out.writeBoolean(this.includeMetadata);
    }

    /**
     * Writes the per-node responses as a list.
     *
     * @param out   stream to write to
     * @param nodes per-node responses to write
     * @throws IOException if writing to the stream fails
     */
    public void writeNodesTo(StreamOutput out, List<NeuralStatsNodeResponse> nodes) throws IOException {
        out.writeList(nodes);
    }

    /**
     * Reads the per-node responses as a list, using
     * {@code NeuralStatsNodeResponse::readStats} for each element.
     *
     * @param in stream to read from
     * @return the per-node responses read from the stream
     * @throws IOException if reading from the stream fails
     */
    public List<NeuralStatsNodeResponse> readNodesFrom(StreamInput in) throws IOException {
        return in.readList(NeuralStatsNodeResponse::readStats);
    }

    /**
     * Renders the stats as three sibling objects, in this order:
     * {@value #INFO_KEY_PREFIX}, {@value #AGGREGATED_NODES_KEY_PREFIX} and
     * {@value #NODES_KEY_PREFIX}.
     *
     * <p>This method does not open or close an enclosing object itself.
     * NOTE(review): presumably the caller wraps the output - confirm.
     *
     * @param builder builder to write into
     * @param params  rendering parameters (not consulted by this method)
     * @return the same builder, for chaining
     * @throws IOException if the builder fails to write
     */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        Map<String, Object> formattedInfoStats = this.formatStats(this.infoStats);
        builder.startObject(INFO_KEY_PREFIX);
        builder.mapContents(formattedInfoStats);
        builder.endObject();

        Map<String, Object> formattedAggregatedNodeStats = this.formatStats(this.aggregatedNodeStats);
        builder.startObject(AGGREGATED_NODES_KEY_PREFIX);
        builder.mapContents(formattedAggregatedNodeStats);
        builder.endObject();

        Map<String, Object> formattedNodeEventStats = this.formatNodeEventStats(this.nodeIdToNodeEventStats);
        builder.startObject(NODES_KEY_PREFIX);
        builder.mapContents(formattedNodeEventStats);
        builder.endObject();
        return builder;
    }

    /**
     * Formats one stat map according to the {@code flatten} flag: flat dotted keys
     * when set, nested maps otherwise.
     */
    private Map<String, Object> formatStats(Map<String, StatSnapshot<?>> rawStats) {
        if (this.flatten) {
            return this.getFlattenedStats(rawStats);
        }
        return this.writeNestedMapWithDotNotation(rawStats, this.includeMetadata);
    }

    /**
     * Returns the stats with their original (dotted) keys. With metadata the given map
     * itself is returned; without it a new map of name to snapshot value is built.
     */
    private Map<String, Object> getFlattenedStats(Map<String, StatSnapshot<?>> rawStats) {
        if (this.includeMetadata) {
            return NeuralStatsResponse.asObjectMap(rawStats);
        }
        // A plain loop rather than Collectors.toMap: toMap rejects null values, while
        // the nested rendering path accepts them, and both paths should agree.
        Map<String, Object> valuesOnly = new HashMap<>();
        for (Map.Entry<String, StatSnapshot<?>> entry : rawStats.entrySet()) {
            valuesOnly.put(entry.getKey(), entry.getValue().getValue());
        }
        return valuesOnly;
    }

    /**
     * Formats the stats of every node, keeping the node id as the outer key.
     */
    private Map<String, Object> formatNodeEventStats(Map<String, Map<String, StatSnapshot<?>>> rawNodeStats) {
        Map<String, Object> formattedByNodeId = new HashMap<>();
        for (Map.Entry<String, Map<String, StatSnapshot<?>>> nodeEntry : rawNodeStats.entrySet()) {
            formattedByNodeId.put(nodeEntry.getKey(), this.formatStats(nodeEntry.getValue()));
        }
        return formattedByNodeId;
    }

    /**
     * Expands dotted keys into nested maps, e.g. {@code "a.b.c" -> v} becomes
     * {@code {a: {b: {c: v}}}}.
     *
     * <p>Keys where one is a strict prefix of another (such as {@code "a"} and
     * {@code "a.b"}) are not supported: depending on iteration order the leaf is either
     * overwritten or the cast of the intermediate node fails.
     *
     * @param dotMap          stats keyed by dotted name
     * @param includeMetadata whether leaves hold the whole snapshot or only its value
     * @return the nested representation
     */
    private Map<String, Object> writeNestedMapWithDotNotation(Map<String, StatSnapshot<?>> dotMap, boolean includeMetadata) {
        Map<String, Object> nestedMap = new HashMap<>();
        for (Map.Entry<String, StatSnapshot<?>> entry : dotMap.entrySet()) {
            String[] parts = entry.getKey().split("\\.");
            Map<String, Object> current = nestedMap;
            // Walk (creating as needed) every intermediate level; the last part is the leaf key.
            for (int i = 0; i < parts.length - 1; ++i) {
                @SuppressWarnings("unchecked") // intermediate nodes are only ever created here, as Map<String, Object>
                Map<String, Object> child = (Map<String, Object>) current.computeIfAbsent(parts[i], k -> new HashMap<String, Object>());
                current = child;
            }
            Object value = includeMetadata ? entry.getValue() : entry.getValue().getValue();
            current.put(parts[parts.length - 1], value);
        }
        return nestedMap;
    }

    /**
     * Views a map with arbitrary value type as {@code Map<String, Object>}.
     * Safe only while the returned view is read and never written through.
     */
    @SuppressWarnings("unchecked")
    private static Map<String, Object> asObjectMap(Map<String, ?> typed) {
        return (Map<String, Object>) typed;
    }

    /**
     * Reads a generic map from the stream and views it with the value type the caller
     * expects. The element types are not checked at read time.
     */
    @SuppressWarnings("unchecked")
    private static <T> Map<String, T> readTypedMap(StreamInput in) throws IOException {
        return (Map<String, T>) (Map<String, ?>) in.readMap();
    }

    @Generated
    public Map<String, StatSnapshot<?>> getInfoStats() {
        return this.infoStats;
    }

    @Generated
    public Map<String, StatSnapshot<?>> getAggregatedNodeStats() {
        return this.aggregatedNodeStats;
    }

    @Generated
    public Map<String, Map<String, StatSnapshot<?>>> getNodeIdToNodeEventStats() {
        return this.nodeIdToNodeEventStats;
    }

    @Generated
    public boolean isFlatten() {
        return this.flatten;
    }

    @Generated
    public boolean isIncludeMetadata() {
        return this.includeMetadata;
    }
}

