/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.plugin.rest;

import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentParserUtils;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.plugin.transport.TrainingJobRouterAction;
import org.opensearch.knn.plugin.transport.TrainingModelRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

public class RestTrainModelHandler
extends BaseRestHandler {
    private static final String NAME = "knn_train_model_action";
    private static final Object DEFAULT_NOT_SET_OBJECT_VALUE = null;
    private static final int DEFAULT_NOT_SET_INT_VALUE = -1;

    public String getName() {
        return NAME;
    }

    public List<RestHandler.Route> routes() {
        return ImmutableList.of((Object)new RestHandler.Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/%s/{%s}/_train", "/_plugins/_knn", "models", "model_id")), (Object)new RestHandler.Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/%s/_train", "/_plugins/_knn", "models")));
    }

    protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
        TrainingModelRequest trainingModelRequest = this.createTransportRequest(restRequest);
        return channel -> client.execute((ActionType)TrainingJobRouterAction.INSTANCE, (ActionRequest)trainingModelRequest, (ActionListener)new RestToXContentListener(channel));
    }

    private TrainingModelRequest createTransportRequest(RestRequest restRequest) throws IOException {
        String modelId = restRequest.param("model_id");
        String preferredNodeId = restRequest.param("preference");
        XContentParser parser = restRequest.contentParser();
        XContentParserUtils.ensureExpectedToken((XContentParser.Token)XContentParser.Token.START_OBJECT, (XContentParser.Token)parser.nextToken(), (XContentParser)parser);
        KNNMethodContext knnMethodContext = (KNNMethodContext)DEFAULT_NOT_SET_OBJECT_VALUE;
        String trainingIndex = (String)DEFAULT_NOT_SET_OBJECT_VALUE;
        String trainingField = (String)DEFAULT_NOT_SET_OBJECT_VALUE;
        String description = (String)DEFAULT_NOT_SET_OBJECT_VALUE;
        int dimension = -1;
        int maximumVectorCount = -1;
        int searchSize = -1;
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();
            if ("training_index".equals(fieldName) && this.ensureNotSet(fieldName, trainingIndex)) {
                trainingIndex = parser.textOrNull();
                continue;
            }
            if ("training_field".equals(fieldName) && this.ensureNotSet(fieldName, trainingField)) {
                trainingField = parser.textOrNull();
                continue;
            }
            if ("method".equals(fieldName) && this.ensureNotSet(fieldName, knnMethodContext)) {
                knnMethodContext = KNNMethodContext.parse(parser.map());
                continue;
            }
            if ("dimension".equals(fieldName) && this.ensureNotSet(fieldName, dimension)) {
                dimension = (Integer)NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
                continue;
            }
            if ("max_training_vector_count".equals(fieldName) && this.ensureNotSet(fieldName, maximumVectorCount)) {
                maximumVectorCount = (Integer)NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
                continue;
            }
            if ("search_size".equals(fieldName) && this.ensureNotSet(fieldName, searchSize)) {
                searchSize = (Integer)NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
                continue;
            }
            if ("description".equals(fieldName) && this.ensureNotSet(fieldName, description)) {
                description = parser.textOrNull();
                continue;
            }
            throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid parameter.");
        }
        this.ensureSet("method", knnMethodContext);
        this.ensureSet("dimension", dimension);
        this.ensureSet("training_index", trainingIndex);
        this.ensureSet("training_field", trainingField);
        if (description == DEFAULT_NOT_SET_OBJECT_VALUE) {
            description = "";
        }
        TrainingModelRequest trainingModelRequest = new TrainingModelRequest(modelId, knnMethodContext, dimension, trainingIndex, trainingField, preferredNodeId, description);
        if (maximumVectorCount != -1) {
            trainingModelRequest.setMaximumVectorCount(maximumVectorCount);
        }
        if (searchSize != -1) {
            trainingModelRequest.setSearchSize(searchSize);
        }
        return trainingModelRequest;
    }

    private void ensureSet(String fieldName, Object value) {
        if (value == DEFAULT_NOT_SET_OBJECT_VALUE) {
            throw new IllegalArgumentException("Request did not set \"" + fieldName + ".");
        }
    }

    private void ensureSet(String fieldName, int value) {
        if (value == -1) {
            throw new IllegalArgumentException("Request did not set \"" + fieldName + ".");
        }
    }

    private boolean ensureNotSet(String fieldName, Object value) {
        if (value != DEFAULT_NOT_SET_OBJECT_VALUE) {
            throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is duplicated.");
        }
        return true;
    }

    private boolean ensureNotSet(String fieldName, int value) {
        if (value != -1) {
            throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is duplicated.");
        }
        return true;
    }
}

