/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.codec.KNN990Codec;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOSupplier;
import org.apache.lucene.util.IOUtils;
import org.opensearch.common.UUIDs;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.memoryoptsearch.VectorSearcher;
import org.opensearch.knn.memoryoptsearch.VectorSearcherFactory;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig;

public class NativeEngines990KnnVectorsReader
extends KnnVectorsReader {
    @Generated
    private static final Logger log = LogManager.getLogger(NativeEngines990KnnVectorsReader.class);
    private static final int RESERVE_TWICE_SPACE = 2;
    private static final float SUFFICIENT_LOAD_FACTOR = 0.6f;
    private final FlatVectorsReader flatVectorsReader;
    private Map<String, String> quantizationStateCacheKeyPerField;
    private final SegmentReadState segmentReadState;
    private final List<String> cacheKeys;
    private volatile Map<String, VectorSearcher> vectorSearchers;

    public NativeEngines990KnnVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader) {
        this(state, flatVectorsReader, false);
    }

    public NativeEngines990KnnVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader, boolean memoryOptimizedSearchEnabled) {
        this.flatVectorsReader = flatVectorsReader;
        this.segmentReadState = state;
        this.cacheKeys = NativeEngines990KnnVectorsReader.getVectorCacheKeysFromSegmentReaderState(state);
        this.loadCacheKeyMap();
        if (memoryOptimizedSearchEnabled && state.context.context() != IOContext.Context.MERGE) {
            this.loadMemoryOptimizedSearcherIfRequired();
        }
    }

    public void checkIntegrity() throws IOException {
        this.flatVectorsReader.checkIntegrity();
    }

    public FloatVectorValues getFloatVectorValues(String field) throws IOException {
        return this.flatVectorsReader.getFloatVectorValues(field);
    }

    public ByteVectorValues getByteVectorValues(String field) throws IOException {
        return this.flatVectorsReader.getByteVectorValues(field);
    }

    public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
        if (knnCollector instanceof QuantizationConfigKNNCollector) {
            String cacheKey = this.quantizationStateCacheKeyPerField.get(field);
            FieldInfo fieldInfo = this.segmentReadState.fieldInfos.fieldInfo(field);
            QuantizationState quantizationState = QuantizationStateCacheManager.getInstance().getQuantizationState(new QuantizationStateReadConfig(this.segmentReadState, QuantizationService.getInstance().getQuantizationParams(fieldInfo), field, cacheKey));
            ((QuantizationConfigKNNCollector)knnCollector).setQuantizationState(quantizationState);
            return;
        }
        if (this.trySearchWithMemoryOptimizedSearch(field, target, knnCollector, acceptDocs, true)) {
            return;
        }
        throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader");
    }

    public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
        if (this.trySearchWithMemoryOptimizedSearch(field, target, knnCollector, acceptDocs, false)) {
            return;
        }
        throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader");
    }

    public void close() throws IOException {
        NativeMemoryCacheManager nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance();
        this.cacheKeys.forEach(nativeMemoryCacheManager::invalidate);
        ArrayList<Object> closeables = new ArrayList<Object>();
        closeables.add(this.flatVectorsReader);
        if (this.vectorSearchers != null) {
            closeables.addAll(this.vectorSearchers.values());
        }
        IOUtils.close(closeables);
        if (this.quantizationStateCacheKeyPerField != null) {
            QuantizationStateCacheManager quantizationStateCacheManager = QuantizationStateCacheManager.getInstance();
            for (String cacheKey : this.quantizationStateCacheKeyPerField.values()) {
                quantizationStateCacheManager.evict(cacheKey);
            }
        }
    }

    private boolean trySearchWithMemoryOptimizedSearch(String field, Object target, KnnCollector knnCollector, Bits acceptDocs, boolean isFloatVector) throws IOException {
        this.loadMemoryOptimizedSearcherIfRequired();
        VectorSearcher memoryOptimizedSearcher = this.vectorSearchers.get(field);
        if (memoryOptimizedSearcher != null) {
            if (isFloatVector) {
                memoryOptimizedSearcher.search((float[])target, knnCollector, acceptDocs);
            } else {
                memoryOptimizedSearcher.search((byte[])target, knnCollector, acceptDocs);
            }
            return true;
        }
        return false;
    }

    private void loadCacheKeyMap() {
        this.quantizationStateCacheKeyPerField = new HashMap<String, String>();
        for (FieldInfo fieldInfo : this.segmentReadState.fieldInfos) {
            String cacheKey = UUIDs.base64UUID();
            this.quantizationStateCacheKeyPerField.put(fieldInfo.getName(), cacheKey);
        }
    }

    private static List<String> getVectorCacheKeysFromSegmentReaderState(SegmentReadState segmentReadState) {
        ArrayList<String> cacheKeys = new ArrayList<String>();
        for (FieldInfo field : segmentReadState.fieldInfos) {
            String vectorIndexFileName = KNNCodecUtil.getNativeEngineFileFromFieldInfo(field, segmentReadState.segmentInfo);
            if (vectorIndexFileName == null) continue;
            String cacheKey = NativeMemoryCacheKeyHelper.constructCacheKey(vectorIndexFileName, segmentReadState.segmentInfo);
            cacheKeys.add(cacheKey);
        }
        return cacheKeys;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void loadMemoryOptimizedSearcherIfRequired() {
        if (this.vectorSearchers != null) {
            return;
        }
        NativeEngines990KnnVectorsReader nativeEngines990KnnVectorsReader = this;
        synchronized (nativeEngines990KnnVectorsReader) {
            if (this.vectorSearchers != null) {
                return;
            }
            HashMap<String, VectorSearcher> vectorSearcherPerField = new HashMap<String, VectorSearcher>(2 * this.segmentReadState.fieldInfos.size(), 0.6f);
            try {
                for (FieldInfo fieldInfo : this.segmentReadState.fieldInfos) {
                    VectorSearcher searcher;
                    IOSupplier<VectorSearcher> searcherSupplier = this.getVectorSearcherSupplier(fieldInfo);
                    if (searcherSupplier == null || (searcher = (VectorSearcher)searcherSupplier.get()) == null) continue;
                    vectorSearcherPerField.put(fieldInfo.getName(), searcher);
                }
                this.vectorSearchers = vectorSearcherPerField;
            }
            catch (Exception e) {
                try {
                    IOUtils.closeWhileHandlingException(vectorSearcherPerField.values());
                }
                catch (Exception closeException) {
                    log.error(closeException.getMessage(), (Throwable)closeException);
                }
                throw new RuntimeException(e);
            }
        }
    }

    private IOSupplier<VectorSearcher> getVectorSearcherSupplier(FieldInfo fieldInfo) {
        Map attributes = fieldInfo.attributes();
        if (attributes == null || !attributes.containsKey("knn_field")) {
            return null;
        }
        KNNEngine knnEngine = FieldInfoExtractor.extractKNNEngine(fieldInfo);
        if (knnEngine == null) {
            return null;
        }
        VectorSearcherFactory searcherFactory = knnEngine.getVectorSearcherFactory();
        if (searcherFactory == null) {
            return null;
        }
        String fileName = KNNCodecUtil.getNativeEngineFileFromFieldInfo(fieldInfo, this.segmentReadState.segmentInfo);
        if (fileName != null) {
            return () -> searcherFactory.createVectorSearcher(this.segmentReadState.directory, fileName);
        }
        return null;
    }
}

