/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query;

import java.util.Locale;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.BaseQueryFactory;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.common.QueryUtils;
import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery;
import org.opensearch.knn.index.query.lucenelib.NestedKnnVectorQueryFactory;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

/**
 * Factory that turns a {@link BaseQueryFactory.CreateQueryRequest} into the {@link Query} used to run a k-NN search.
 *
 * <p>Two query families can be produced:
 * <ul>
 *   <li>a {@link NativeEngineKnnVectorQuery} wrapping a {@link KNNQuery}. This is used when memory-optimized search is
 *       not supported and the request's engine is listed in
 *       {@link KNNEngine#getEnginesThatCreateCustomSegmentFiles()};</li>
 *   <li>otherwise a {@link LuceneEngineKnnVectorQuery} wrapping Lucene's {@link KnnFloatVectorQuery} or
 *       {@link KnnByteVectorQuery}, or the nested variant built by {@link NestedKnnVectorQueryFactory} when a
 *       parent filter is available.</li>
 * </ul>
 */
public class KNNQueryFactory extends BaseQueryFactory {
    @Generated
    private static final Logger log = LogManager.getLogger(KNNQueryFactory.class);

    /** Key in the request's method parameters that carries the ef_search value. */
    private static final String METHOD_PARAMETER_EF_SEARCH = "ef_search";

    /**
     * Creates the k-NN query described by {@code createQueryRequest}.
     *
     * @param createQueryRequest request carrying the following: index and field names, {@code k}, the float and byte
     *                           query vectors, the vector data type, an optional filter, method parameters, the
     *                           rescore context, the expand-nested flag and an optional shard context
     * @return a {@link NativeEngineKnnVectorQuery} when memory-optimized search is not supported and the engine
     *         creates custom segment files; otherwise a {@link LuceneEngineKnnVectorQuery}
     * @throws IllegalArgumentException if {@code expand_nested_docs} is requested but the shard context provides no
     *                                  parent filter, or if the Lucene path receives a vector data type other than
     *                                  BINARY, BYTE or FLOAT
     */
    public static Query create(BaseQueryFactory.CreateQueryRequest createQueryRequest) {
        String indexName = createQueryRequest.getIndexName();
        String fieldName = createQueryRequest.getFieldName();
        int k = createQueryRequest.getK();
        float[] vector = createQueryRequest.getVector();
        byte[] byteVector = createQueryRequest.getByteVector();
        VectorDataType vectorDataType = createQueryRequest.getVectorDataType();
        Query filterQuery = getFilterQuery(createQueryRequest);
        Map<String, ?> methodParameters = createQueryRequest.getMethodParameters();
        RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null);
        boolean expandNested = createQueryRequest.getExpandNested().orElse(false);
        boolean memoryOptimizedSearchSupported = createQueryRequest.isMemoryOptimizedSearchSupported();

        // The parent filter and shard id come from the optional shard context.
        // Without a context, there is no parent filter and the shard id is -1.
        QueryShardContext context = createQueryRequest.getContext().orElse(null);
        BitSetProducer parentFilter = context == null ? null : context.getParentFilter();
        int shardId = context == null ? -1 : context.getShardId();

        // expand_nested_docs needs a parent filter. Its absence is reported as a non-nested field.
        if (parentFilter == null && expandNested) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Invalid value provided for the [%s] field. [%s] is only supported with a nested field.", "expand_nested_docs", "expand_nested_docs"));
        }

        if (!memoryOptimizedSearchSupported && KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
            // The filter is only forwarded to engines that support filtering. For the other engines it is dropped.
            Query validatedFilterQuery = validateFilterQuerySupport(filterQuery, createQueryRequest.getKnnEngine());
            log.debug(
                "Creating custom k-NN query for index:{}, field:{}, k:{}, filterQuery:{}, methodParameters:{}",
                indexName,
                fieldName,
                k,
                validatedFilterQuery,
                methodParameters
            );
            KNNQuery knnQuery;
            switch (vectorDataType) {
                case BINARY:
                    // Binary fields are queried with the byte vector.
                    knnQuery = KNNQuery.builder()
                        .field(fieldName)
                        .byteQueryVector(byteVector)
                        .indexName(indexName)
                        .parentsFilter(parentFilter)
                        .k(k)
                        .methodParameters(methodParameters)
                        .filterQuery(validatedFilterQuery)
                        .vectorDataType(vectorDataType)
                        .rescoreContext(rescoreContext)
                        .shardId(shardId)
                        .build();
                    break;
                default:
                    // Every other data type is queried with the float vector.
                    knnQuery = KNNQuery.builder()
                        .field(fieldName)
                        .queryVector(vector)
                        .indexName(indexName)
                        .parentsFilter(parentFilter)
                        .k(k)
                        .methodParameters(methodParameters)
                        .filterQuery(validatedFilterQuery)
                        .vectorDataType(vectorDataType)
                        .rescoreContext(rescoreContext)
                        .shardId(shardId)
                        .build();
                    break;
            }
            return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested);
        }

        // The Lucene query receives max(k, ef_search) as its k. The requested k is what gets logged.
        int luceneK = resolveLuceneK(k, methodParameters);
        log.debug("Creating Lucene k-NN query for index: {}, field:{}, k: {}", indexName, fieldName, k);
        switch (vectorDataType) {
            case BINARY:
            case BYTE:
                return new LuceneEngineKnnVectorQuery(getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, expandNested));
            case FLOAT:
                return new LuceneEngineKnnVectorQuery(getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter, expandNested));
        }
        throw new IllegalArgumentException(String.format(Locale.ROOT, "Invalid value provided for [%s] field. Supported values are [%s], but got: %s", "data_type", VectorDataType.SUPPORTED_VECTOR_DATA_TYPES, vectorDataType));
    }

    /**
     * Computes the {@code k} passed to Lucene's vector query.
     *
     * <p>The result is the requested {@code k}, raised to the {@code ef_search} method parameter when one is present
     * and larger.
     *
     * @param k                the requested number of neighbours
     * @param methodParameters the request's method parameters; may be {@code null}
     * @return {@code k}, or {@code max(k, ef_search)} when {@code ef_search} is set to a non-null value
     * @throws ClassCastException if {@code ef_search} is present but is not a {@link Number}
     */
    private static int resolveLuceneK(int k, Map<String, ?> methodParameters) {
        Object efSearch = methodParameters == null ? null : methodParameters.get(METHOD_PARAMETER_EF_SEARCH);
        if (efSearch == null) {
            return k;
        }
        // The map is typed Map<String, ?>, so any numeric boxing is accepted, not only Integer.
        // A Long or Short value would otherwise fail with a ClassCastException. Non-numeric values still do.
        return Math.max(k, ((Number) efSearch).intValue());
    }

    /**
     * Decides whether the filter query is passed to a native engine.
     *
     * @param filterQuery the filter from the request; may be {@code null}
     * @param knnEngine   the engine the query will run against
     * @return {@code filterQuery} if it is non-null and {@code knnEngine} is listed in
     *         {@link KNNEngine#getEnginesThatSupportsFilters()}; otherwise {@code null}, meaning no filter is applied
     */
    private static Query validateFilterQuerySupport(Query filterQuery, KNNEngine knnEngine) {
        log.debug("filter query {}, knnEngine {}", filterQuery, knnEngine);
        if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) {
            return filterQuery;
        }
        return null;
    }

    /**
     * Builds a Lucene k-NN query over a byte-vector field.
     *
     * <p>Without a parent filter, a plain {@link KnnByteVectorQuery} is returned. With one, construction is delegated
     * to {@link NestedKnnVectorQueryFactory}.
     */
    private static Query getKnnByteVectorQuery(String fieldName, byte[] byteVector, int k, Query filterQuery, BitSetProducer parentFilter, boolean expandNested) {
        if (parentFilter == null) {
            // create() already rejects expandNested without a parent filter, so this must hold here.
            assert !expandNested : "expandNested is allowed to be true only for nested fields.";
            return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
        }
        return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter, expandNested);
    }

    /**
     * Builds a Lucene k-NN query over a float-vector field.
     *
     * <p>Without a parent filter, a plain {@link KnnFloatVectorQuery} is returned. With one, construction is delegated
     * to {@link NestedKnnVectorQueryFactory}.
     */
    private static Query getKnnFloatVectorQuery(String fieldName, float[] floatVector, int k, Query filterQuery, BitSetProducer parentFilter, boolean expandNested) {
        if (parentFilter == null) {
            // create() already rejects expandNested without a parent filter, so this must hold here.
            assert !expandNested : "expandNested is allowed to be true only for nested fields.";
            return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
        }
        return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(fieldName, floatVector, k, filterQuery, parentFilter, expandNested);
    }
}

