/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import com.google.common.annotations.VisibleForTesting;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiConsumer;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.get.GetAction;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.MapInferenceRequest;
import org.opensearch.neuralsearch.processor.optimization.TextImageEmbeddingInferenceFilter;
import org.opensearch.transport.client.OpenSearchClient;

public class TextImageEmbeddingProcessor
extends AbstractProcessor {
    @Generated
    private static final Logger log = LogManager.getLogger(TextImageEmbeddingProcessor.class);
    public static final String TYPE = "text_image_embedding";
    public static final String MODEL_ID_FIELD = "model_id";
    public static final String EMBEDDING_FIELD = "embedding";
    public static final boolean DEFAULT_SKIP_EXISTING = false;
    public static final String SKIP_EXISTING = "skip_existing";
    public static final String FIELD_MAP_FIELD = "field_map";
    public static final String TEXT_FIELD_NAME = "text";
    public static final String IMAGE_FIELD_NAME = "image";
    public static final String INPUT_TEXT = "inputText";
    public static final String INPUT_IMAGE = "inputImage";
    private static final String INDEX_FIELD = "_index";
    private static final String ID_FIELD = "_id";
    private static final Set<String> VALID_FIELD_NAMES = Set.of("text", "image");
    private final String modelId;
    private final String embedding;
    private final Map<String, String> fieldMap;
    private final boolean skipExisting;
    private final OpenSearchClient openSearchClient;
    private final MLCommonsClientAccessor mlCommonsClientAccessor;
    private final TextImageEmbeddingInferenceFilter inferenceFilter;
    private final Environment environment;
    private final ClusterService clusterService;

    public TextImageEmbeddingProcessor(String tag, String description, String modelId, String embedding, Map<String, String> fieldMap, boolean skipExisting, TextImageEmbeddingInferenceFilter inferenceFilter, OpenSearchClient openSearchClient, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService) {
        super(tag, description);
        if (StringUtils.isBlank((CharSequence)modelId)) {
            throw new IllegalArgumentException("model_id is null or empty, can not process it");
        }
        this.validateEmbeddingConfiguration(fieldMap);
        this.modelId = modelId;
        this.embedding = embedding;
        this.fieldMap = fieldMap;
        this.mlCommonsClientAccessor = clientAccessor;
        this.environment = environment;
        this.clusterService = clusterService;
        this.skipExisting = skipExisting;
        this.inferenceFilter = inferenceFilter;
        this.openSearchClient = openSearchClient;
    }

    private void validateEmbeddingConfiguration(Map<String, String> fieldMap) {
        if (fieldMap == null || fieldMap.isEmpty() || fieldMap.entrySet().stream().anyMatch(x -> StringUtils.isBlank((CharSequence)((CharSequence)x.getKey())) || Objects.isNull(x.getValue()))) {
            throw new IllegalArgumentException("Unable to create the TextImageEmbedding processor as field_map has invalid key or value");
        }
        if (fieldMap.entrySet().stream().anyMatch(entry -> !VALID_FIELD_NAMES.contains(entry.getKey()))) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Unable to create the TextImageEmbedding processor with provided field name(s). Following names are supported [%s]", String.join((CharSequence)",", VALID_FIELD_NAMES)));
        }
    }

    public IngestDocument execute(IngestDocument ingestDocument) {
        return ingestDocument;
    }

    public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
        try {
            Map<String, String> knnMap = this.buildMapWithKnnKeyAndOriginalValue(ingestDocument);
            Map<String, String> inferenceMap = this.createInferences(knnMap);
            if (inferenceMap.isEmpty()) {
                handler.accept(ingestDocument, null);
                return;
            }
            if (!this.skipExisting) {
                this.generateAndSetInference(ingestDocument, inferenceMap, handler);
                return;
            }
            Object index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD);
            Object id = ingestDocument.getSourceAndMetadata().get(ID_FIELD);
            if (Objects.isNull(index) || Objects.isNull(id)) {
                this.generateAndSetInference(ingestDocument, inferenceMap, handler);
                return;
            }
            this.openSearchClient.execute((ActionType)GetAction.INSTANCE, (ActionRequest)new GetRequest(index.toString(), id.toString()), ActionListener.wrap(response -> this.reuseOrGenerateEmbedding((GetResponse)response, ingestDocument, knnMap, inferenceMap, handler), e -> handler.accept((IngestDocument)null, (Exception)e)));
        }
        catch (Exception e2) {
            handler.accept(null, e2);
        }
    }

    private void setVectorFieldsToDocument(IngestDocument ingestDocument, List<Number> vectors) {
        Objects.requireNonNull(vectors, "embedding failed, inference returns null result!");
        log.debug("Text embedding result fetched, starting build vector output!");
        Map<String, Object> textEmbeddingResult = this.buildTextEmbeddingResult(this.embedding, vectors);
        textEmbeddingResult.forEach((arg_0, arg_1) -> ((IngestDocument)ingestDocument).setFieldValue(arg_0, arg_1));
    }

    private Map<String, String> createInferences(Map<String, String> knnKeyMap) {
        HashMap<String, String> texts = new HashMap<String, String>();
        if (this.fieldMap.containsKey(TEXT_FIELD_NAME) && knnKeyMap.containsKey(this.fieldMap.get(TEXT_FIELD_NAME))) {
            texts.put(INPUT_TEXT, knnKeyMap.get(this.fieldMap.get(TEXT_FIELD_NAME)));
        }
        if (this.fieldMap.containsKey(IMAGE_FIELD_NAME) && knnKeyMap.containsKey(this.fieldMap.get(IMAGE_FIELD_NAME))) {
            texts.put(INPUT_IMAGE, knnKeyMap.get(this.fieldMap.get(IMAGE_FIELD_NAME)));
        }
        return texts;
    }

    @VisibleForTesting
    Map<String, String> buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) {
        Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
        LinkedHashMap<String, String> mapWithKnnKeys = new LinkedHashMap<String, String>();
        for (Map.Entry<String, String> fieldMapEntry : this.fieldMap.entrySet()) {
            String originalKey = fieldMapEntry.getValue();
            if (!sourceAndMetadataMap.containsKey(originalKey)) continue;
            Object metaValue = sourceAndMetadataMap.get(originalKey);
            if (Objects.isNull(metaValue)) {
                throw new IllegalArgumentException(String.format(Locale.getDefault(), "Unsupported format of the field in the document, %s value must be a non-empty string. Currently it is null", originalKey));
            }
            if (!(metaValue instanceof String) || Objects.isNull(metaValue) || StringUtils.isBlank((CharSequence)((String)metaValue))) {
                throw new IllegalArgumentException(String.format(Locale.getDefault(), "Unsupported format of the field in the document, %s value must be a non-empty string. Currently it is '%s'. Type is %s", originalKey, metaValue, metaValue.getClass()));
            }
            mapWithKnnKeys.put(originalKey, (String)sourceAndMetadataMap.get(originalKey));
        }
        return mapWithKnnKeys;
    }

    @VisibleForTesting
    Map<String, Object> buildTextEmbeddingResult(String knnKey, List<Number> modelTensorList) {
        LinkedHashMap<String, Object> result = new LinkedHashMap<String, Object>();
        result.put(knnKey, modelTensorList);
        return result;
    }

    public String getType() {
        return TYPE;
    }

    private void generateAndSetInference(IngestDocument ingestDocument, Map<String, String> inferenceMap, BiConsumer<IngestDocument, Exception> handler) {
        this.mlCommonsClientAccessor.inferenceSentencesMap((MapInferenceRequest)((MapInferenceRequest.MapInferenceRequestBuilder)((MapInferenceRequest.MapInferenceRequestBuilder)MapInferenceRequest.builder().modelId(this.modelId)).inputObjects(inferenceMap)).build(), (ActionListener<List<Number>>)ActionListener.wrap(vectors -> {
            this.setVectorFieldsToDocument(ingestDocument, (List<Number>)vectors);
            handler.accept(ingestDocument, null);
        }, e -> handler.accept((IngestDocument)null, (Exception)e)));
    }

    private void reuseOrGenerateEmbedding(GetResponse response, IngestDocument ingestDocument, Map<String, String> knnMap, Map<String, String> inferenceMap, BiConsumer<IngestDocument, Exception> handler) {
        Map existingDocument = response.getSourceAsMap();
        if (existingDocument == null || existingDocument.isEmpty()) {
            this.generateAndSetInference(ingestDocument, inferenceMap, handler);
            return;
        }
        Map<String, String> filteredKnnMap = this.inferenceFilter.filterAndCopyExistingEmbeddings(ingestDocument, existingDocument, knnMap, this.embedding);
        Map<String, String> filteredInferenceMap = this.createInferences(filteredKnnMap);
        if (filteredInferenceMap.isEmpty()) {
            handler.accept(ingestDocument, null);
        } else {
            this.generateAndSetInference(ingestDocument, filteredInferenceMap, handler);
        }
    }
}

