/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.rest;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.action.stats.MLStatsNodesAction;
import org.opensearch.ml.action.stats.MLStatsNodesRequest;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

public class RestStatsMLAction
extends BaseRestHandler {
    private static final String STATS_ML_ACTION = "stats_ml";
    private MLStats mlStats;

    public RestStatsMLAction(MLStats mlStats) {
        this.mlStats = mlStats;
    }

    public String getName() {
        return STATS_ML_ACTION;
    }

    public List<RestHandler.Route> routes() {
        return ImmutableList.of((Object)new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/{nodeId}/stats/"), (Object)new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/{nodeId}/stats/{stat}"), (Object)new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/stats/"), (Object)new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/stats/{stat}"));
    }

    protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
        MLStatsNodesRequest mlStatsNodesRequest = this.getRequest(request);
        return channel -> client.execute((ActionType)MLStatsNodesAction.INSTANCE, (ActionRequest)mlStatsNodesRequest, (ActionListener)new RestToXContentListener(channel));
    }

    @VisibleForTesting
    MLStatsNodesRequest getRequest(RestRequest request) {
        MLStatsNodesRequest mlStatsRequest = new MLStatsNodesRequest(this.splitCommaSeparatedParam(request, "nodeId").orElse(null));
        mlStatsRequest.timeout(request.param("timeout"));
        List requestedStats = this.splitCommaSeparatedParam(request, "stat").map(Arrays::asList).orElseGet(Collections::emptyList);
        Set<String> validStats = this.mlStats.getStats().keySet();
        if (this.isAllStatsRequested(requestedStats)) {
            mlStatsRequest.setRetrieveAllStats(true);
        } else {
            mlStatsRequest.addAll(this.getStatsToBeRetrieved(request, validStats, requestedStats));
        }
        return mlStatsRequest;
    }

    @VisibleForTesting
    Set<String> getStatsToBeRetrieved(RestRequest request, Set<String> validStats, List<String> requestedStats) {
        if (requestedStats.contains("_all")) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Request %s contains both %s and individual stats", request.path(), "_all"));
        }
        Set invalidStats = requestedStats.stream().filter(s -> !validStats.contains(s)).collect(Collectors.toSet());
        if (!invalidStats.isEmpty()) {
            throw new IllegalArgumentException(this.unrecognized(request, invalidStats, new HashSet<String>(requestedStats), "stat"));
        }
        return new HashSet<String>(requestedStats);
    }

    @VisibleForTesting
    boolean isAllStatsRequested(List<String> requestedStats) {
        return requestedStats.isEmpty() || requestedStats.size() == 1 && requestedStats.contains("_all");
    }

    @VisibleForTesting
    Optional<String[]> splitCommaSeparatedParam(RestRequest request, String paramName) {
        return Optional.ofNullable(request.param(paramName)).map(s -> s.split(","));
    }
}

