/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.forecast.ml;

import com.amazon.randomcutforest.parkservices.RCFCaster;
import com.amazon.randomcutforest.parkservices.state.RCFCasterMapper;
import com.amazon.randomcutforest.parkservices.state.RCFCasterState;
import com.google.gson.Gson;
import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Clock;
import java.time.Instant;
import java.time.ZoneOffset;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.core.action.ActionListener;
import org.opensearch.forecast.indices.ForecastIndex;
import org.opensearch.forecast.indices.ForecastIndexManagement;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.timeseries.ml.CheckpointDao;
import org.opensearch.timeseries.ml.ModelManager;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Entity;
import org.opensearch.timeseries.util.ClientUtil;
import org.opensearch.transport.client.Client;

public class ForecastCheckpointDao
extends CheckpointDao<RCFCaster, ForecastIndex, ForecastIndexManagement> {
    public static final Logger logger = LogManager.getLogger(ForecastCheckpointDao.class);
    static final String NOT_ABLE_TO_DELETE_CHECKPOINT_MSG = "Cannot delete all checkpoints of forecaster";
    RCFCasterMapper mapper;
    private Schema<RCFCasterState> rcfCasterSchema;

    public ForecastCheckpointDao(Client client, ClientUtil clientUtil, Gson gson, int maxCheckpointBytes, GenericObjectPool<LinkedBuffer> serializeRCFBufferPool, int serializeRCFBufferSize, ForecastIndexManagement indexUtil, RCFCasterMapper mapper, Schema<RCFCasterState> rcfCasterSchema, Clock clock) {
        super(client, clientUtil, ForecastIndex.CHECKPOINT.getIndexName(), gson, maxCheckpointBytes, serializeRCFBufferPool, serializeRCFBufferSize, indexUtil, clock);
        this.mapper = mapper;
        this.rcfCasterSchema = rcfCasterSchema;
    }

    public void putCasterCheckpoint(String modelId, RCFCaster caster, ActionListener<Void> listener) {
        HashMap<String, Object> source = new HashMap<String, Object>();
        Optional<String> modelCheckpoint = this.toCheckpoint(Optional.of(caster));
        if (modelCheckpoint.isPresent()) {
            source.put("model", modelCheckpoint.get());
            source.put("timestamp", this.clock.instant().atZone(ZoneOffset.UTC));
            source.put("timestamp", this.clock.instant().atZone(ZoneOffset.UTC));
            source.put("schema_version", this.indexUtil.getSchemaVersion(ForecastIndex.CHECKPOINT));
            this.putModelCheckpoint(modelId, source, listener);
        } else {
            listener.onFailure((Exception)new RuntimeException("Fail to create checkpoint to save"));
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private Optional<String> toCheckpoint(Optional<RCFCaster> caster) {
        Optional<String> checkpoint;
        block17: {
            if (caster.isEmpty()) {
                return Optional.empty();
            }
            checkpoint = Optional.empty();
            Map.Entry<LinkedBuffer, Boolean> result = this.checkoutOrNewBuffer();
            LinkedBuffer buffer = result.getKey();
            boolean needCheckin = result.getValue();
            try {
                checkpoint = this.toCheckpoint(caster, buffer);
            }
            catch (Exception e) {
                logger.error("Failed to serialize model", (Throwable)e);
                if (!needCheckin) break block17;
                try {
                    this.serializeRCFBufferPool.invalidateObject((Object)buffer);
                    needCheckin = false;
                }
                catch (Exception x) {
                    logger.warn("Failed to invalidate buffer", (Throwable)x);
                }
                try {
                    checkpoint = this.toCheckpoint(caster, LinkedBuffer.allocate((int)this.serializeRCFBufferSize));
                }
                catch (Exception ex) {
                    logger.warn("Failed to generate checkpoint", (Throwable)ex);
                }
            }
            finally {
                if (needCheckin) {
                    try {
                        this.serializeRCFBufferPool.returnObject((Object)buffer);
                    }
                    catch (Exception e) {
                        logger.warn("Failed to return buffer to pool", (Throwable)e);
                    }
                }
            }
        }
        return checkpoint;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private Optional<String> toCheckpoint(Optional<RCFCaster> caster, LinkedBuffer buffer) {
        if (caster.isEmpty()) {
            return Optional.empty();
        }
        try {
            byte[] bytes = AccessController.doPrivileged(() -> {
                RCFCasterState casterState = this.mapper.toState((RCFCaster)caster.get());
                return ProtostuffIOUtil.toByteArray((Object)casterState, this.rcfCasterSchema, (LinkedBuffer)buffer);
            });
            Optional<String> optional = Optional.ofNullable(Base64.getEncoder().encodeToString(bytes));
            return optional;
        }
        finally {
            buffer.clear();
        }
    }

    @Override
    public Map<String, Object> toIndexSource(ModelState<RCFCaster> modelState) throws IOException {
        HashMap<String, Object> source = new HashMap<String, Object>();
        Optional<RCFCaster> model = modelState.getModel();
        Optional<String> serializedModel = this.toCheckpoint(model);
        if (serializedModel.isPresent() && serializedModel.get().length() <= this.maxCheckpointBytes) {
            source.put("model", serializedModel.get());
        } else {
            logger.warn((Message)new ParameterizedMessage("[{}]'s model is empty or too large: [{}] bytes", (Object)modelState.getModelId(), (Object)(serializedModel.isPresent() ? serializedModel.get().length() : 0)));
        }
        Optional<Sample[]> samples = this.toCheckpoint(modelState.getSamples());
        if (samples.isPresent()) {
            source.put("samples", samples.get());
        }
        if (!source.containsKey("samples") && !source.containsKey("model")) {
            logger.info("nothing to save for [{}]", (Object)modelState.getModelId());
            return source;
        }
        source.put("forecaster_id", modelState.getConfigId());
        source.put("timestamp", this.clock.instant().atZone(ZoneOffset.UTC));
        source.put("schema_version", this.indexUtil.getSchemaVersion(ForecastIndex.CHECKPOINT));
        Optional<Entity> entity = modelState.getEntity();
        if (entity.isPresent()) {
            source.put("entity", entity.get());
        }
        return source;
    }

    private void deserializeRCFCasterModel(GetResponse response, String rcfModelId, ActionListener<Optional<RCFCaster>> listener) {
        Object model = null;
        if (response.isExists()) {
            try {
                model = response.getSource().get("model");
                listener.onResponse(Optional.ofNullable(this.toRCFCaster(model)));
            }
            catch (Exception e) {
                logger.error((Message)new ParameterizedMessage("Unexpected error when deserializing [{}]", (Object)rcfModelId), (Throwable)e);
                listener.onResponse(Optional.empty());
            }
        } else {
            listener.onResponse(Optional.empty());
        }
    }

    RCFCaster toRCFCaster(String checkpoint) {
        RCFCaster rcfCaster = null;
        if (checkpoint != null && checkpoint.length() > 0) {
            try {
                byte[] bytes = Base64.getDecoder().decode(checkpoint);
                RCFCasterState state = (RCFCasterState)this.rcfCasterSchema.newMessage();
                AccessController.doPrivileged(() -> {
                    ProtostuffIOUtil.mergeFrom((byte[])bytes, (Object)state, this.rcfCasterSchema);
                    return null;
                });
                rcfCaster = (RCFCaster)this.mapper.toModel((Object)state);
            }
            catch (RuntimeException e) {
                logger.error("Failed to deserialize RCFCaster model", (Throwable)e);
            }
        }
        return rcfCaster;
    }

    public void getCasterModel(String modelId, ActionListener<Optional<RCFCaster>> listener) {
        this.clientUtil.asyncRequest(new GetRequest(this.indexName, modelId), (arg_0, arg_1) -> ((Client)this.client).get(arg_0, arg_1), ActionListener.wrap(response -> this.deserializeRCFCasterModel((GetResponse)response, modelId, listener), exception -> {
            if (exception instanceof IndexNotFoundException) {
                listener.onResponse(Optional.empty());
            } else {
                listener.onFailure(exception);
            }
        }));
    }

    @Override
    protected ModelState<RCFCaster> fromEntityModelCheckpoint(Map<String, Object> checkpoint, String modelId, String configId) {
        try {
            return AccessController.doPrivileged(() -> {
                RCFCaster rcfCaster = this.loadRCFCaster(checkpoint, modelId);
                Entity entity = null;
                Object serializedEntity = checkpoint.get("entity");
                if (serializedEntity != null) {
                    try {
                        entity = Entity.fromJsonArray(serializedEntity);
                    }
                    catch (Exception e) {
                        logger.error((Message)new ParameterizedMessage("fail to parse entity", serializedEntity), (Throwable)e);
                    }
                }
                ModelState<RCFCaster> modelState = new ModelState<RCFCaster>(rcfCaster, modelId, configId, ModelManager.ModelType.RCFCASTER.getName(), this.clock, 0.0f, Optional.ofNullable(entity), this.loadSampleQueue(checkpoint, modelId));
                modelState.setLastCheckpointTime(this.loadTimestamp(checkpoint, modelId));
                return modelState;
            });
        }
        catch (Exception e) {
            logger.warn("Exception while deserializing checkpoint " + modelId, (Throwable)e);
            return null;
        }
    }

    public void deleteModelCheckpointByForecasterId(String forecasterId) {
        DeleteByQueryRequest deleteRequest = (DeleteByQueryRequest)((DeleteByQueryRequest)new DeleteByQueryRequest(new String[]{this.indexName}).setQuery((QueryBuilder)new MatchQueryBuilder("forecaster_id", (Object)forecasterId)).setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN).setAbortOnVersionConflict(false)).setRequestsPerSecond(500.0f);
        logger.info("Delete checkpoints of forecaster {}", (Object)forecasterId);
        this.client.execute((ActionType)DeleteByQueryAction.INSTANCE, (ActionRequest)deleteRequest, ActionListener.wrap(response -> {
            if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) {
                this.logFailure((BulkByScrollResponse)response, forecasterId);
            }
            logger.info("{} checkpoints docs get deleted", (Object)response.getDeleted());
        }, exception -> {
            if (exception instanceof IndexNotFoundException) {
                logger.info("Checkpoint index has been deleted.  Has nothing to do: {}", (Object)forecasterId);
            } else {
                logger.error(NOT_ABLE_TO_DELETE_CHECKPOINT_MSG, (Throwable)exception);
            }
        }));
    }

    @Override
    protected DeleteByQueryRequest createDeleteCheckpointRequest(String configId) {
        return (DeleteByQueryRequest)((DeleteByQueryRequest)new DeleteByQueryRequest(new String[]{this.indexName}).setQuery((QueryBuilder)new MatchQueryBuilder("forecaster_id", (Object)configId)).setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN).setAbortOnVersionConflict(false)).setRequestsPerSecond(500.0f);
    }

    @Override
    protected ModelState<RCFCaster> fromSingleStreamModelCheckpoint(Map<String, Object> checkpoint, String modelId, String configId) {
        return AccessController.doPrivileged(() -> {
            RCFCaster rcfCaster = this.loadRCFCaster(checkpoint, modelId);
            ModelState<RCFCaster> modelState = new ModelState<RCFCaster>(rcfCaster, modelId, configId, ModelManager.ModelType.RCFCASTER.getName(), this.clock, 0.0f, Optional.empty(), this.loadSampleQueue(checkpoint, modelId));
            modelState.setLastCheckpointTime(this.loadTimestamp(checkpoint, modelId));
            return modelState;
        });
    }

    private RCFCaster loadRCFCaster(Map<String, Object> checkpoint, String modelId) {
        String model = (String)checkpoint.get("model");
        if (model == null || model.length() > this.maxCheckpointBytes) {
            logger.warn((Message)new ParameterizedMessage("[{}]'s model too large: [{}] bytes", (Object)modelId, (Object)model.length()));
            return null;
        }
        return this.toRCFCaster(model);
    }

    private Instant loadTimestamp(Map<String, Object> checkpoint, String modelId) {
        String lastCheckpointTimeString = (String)checkpoint.get("timestamp");
        return Instant.parse(lastCheckpointTimeString);
    }
}

