/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ad.ml;

import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.AbstractMap;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.core.util.Throwables;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.ad.CleanState;
import org.opensearch.ad.MaintenanceState;
import org.opensearch.ad.NodeStateManager;
import org.opensearch.ad.caching.DoorKeeper;
import org.opensearch.ad.common.exception.AnomalyDetectionException;
import org.opensearch.ad.common.exception.EndRunException;
import org.opensearch.ad.dataprocessor.Interpolator;
import org.opensearch.ad.feature.FeatureManager;
import org.opensearch.ad.feature.SearchFeatureDao;
import org.opensearch.ad.ml.EntityModel;
import org.opensearch.ad.ml.ModelState;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.AnomalyDetectorJob;
import org.opensearch.ad.model.Entity;
import org.opensearch.ad.model.IntervalTimeConfiguration;
import org.opensearch.ad.ratelimit.CheckpointWriteWorker;
import org.opensearch.ad.ratelimit.RequestPriority;
import org.opensearch.ad.settings.AnomalyDetectorSettings;
import org.opensearch.ad.util.ExceptionUtil;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.threadpool.ThreadPool;

/**
 * Cold-starts entity models.
 *
 * <p>For an entity whose model state holds too few samples, this class:
 * <ul>
 *   <li>fetches historical feature data for the entity;</li>
 *   <li>interpolates between the sampled points;</li>
 *   <li>merges the result with the samples already held by the model state;</li>
 *   <li>trains a {@link ThresholdedRandomCutForest} once at least {@code numMinSamples} samples are available.</li>
 * </ul>
 *
 * <p>A per-detector {@link DoorKeeper} suppresses repeated cold-start attempts for the same model id. After an
 * overload failure, all cold starts are skipped for {@code coolDownMinutes}.
 */
public class EntityColdStarter implements MaintenanceState, CleanState {
    private static final Logger logger = LogManager.getLogger(EntityColdStarter.class);

    /** Executor name used for cold-start work and for the threaded listeners below. */
    private static final String AD_THREAD_POOL_NAME = "ad-threadpool";
    /** First constructor argument of every per-detector DoorKeeper (presumably expected insertions — confirm). */
    private static final long DOOR_KEEPER_MAX_INSERTION = 100000L;
    /** Second constructor argument of every DoorKeeper (presumably its false-positive rate — confirm). */
    private static final double DOOR_KEEPER_FALSE_POSITIVE_RATE = 0.01;
    /** Multiplier applied to the detector interval to build the Duration handed to each DoorKeeper. */
    private static final long DOOR_KEEPER_INTERVAL_MULTIPLIER = 60L;

    private final Clock clock;
    private final ThreadPool threadPool;
    private final NodeStateManager nodeStateManager;
    private final int rcfSampleSize;
    private final int numberOfTrees;
    private final double rcfTimeDecay;
    private final int numMinSamples;
    private final double thresholdMinPvalue;
    // Stored for constructor compatibility; selectRangeParam currently overrides both in every branch.
    private final int defaultStrideLength;
    private final int defaultNumberOfSamples;
    private final Interpolator interpolator;
    private final SearchFeatureDao searchFeatureDao;
    // Written from listener callbacks on the AD thread pool and read by later cold-start calls, hence volatile.
    private volatile Instant lastThrottledColdStartTime;
    private final FeatureManager featureManager;
    private final int coolDownMinutes;
    private final Map<String, DoorKeeper> doorKeepers;
    private final Duration modelTtl;
    private final CheckpointWriteWorker checkpointWriteQueue;
    // Applied to the RCF builder only when positive; the shorter constructor passes -1 (no explicit seed).
    private final long rcfSeed;
    private final int maxRoundofColdStart;
    private final double initialAcceptFraction;

    public EntityColdStarter(Clock clock, ThreadPool threadPool, NodeStateManager nodeStateManager, int rcfSampleSize,
            int numberOfTrees, double rcfTimeDecay, int numMinSamples, int defaultSampleStride, int defaultTrainSamples,
            Interpolator interpolator, SearchFeatureDao searchFeatureDao, double thresholdMinPvalue,
            FeatureManager featureManager, Settings settings, Duration modelTtl,
            CheckpointWriteWorker checkpointWriteQueue, long rcfSeed, int maxRoundofColdStart) {
        this.clock = clock;
        this.lastThrottledColdStartTime = Instant.MIN;
        this.threadPool = threadPool;
        this.nodeStateManager = nodeStateManager;
        this.rcfSampleSize = rcfSampleSize;
        this.numberOfTrees = numberOfTrees;
        this.rcfTimeDecay = rcfTimeDecay;
        this.numMinSamples = numMinSamples;
        this.defaultStrideLength = defaultSampleStride;
        this.defaultNumberOfSamples = defaultTrainSamples;
        this.interpolator = interpolator;
        this.searchFeatureDao = searchFeatureDao;
        this.thresholdMinPvalue = thresholdMinPvalue;
        this.featureManager = featureManager;
        this.coolDownMinutes = (int) ((TimeValue) AnomalyDetectorSettings.COOLDOWN_MINUTES.get(settings)).getMinutes();
        this.doorKeepers = new ConcurrentHashMap<>();
        this.modelTtl = modelTtl;
        this.checkpointWriteQueue = checkpointWriteQueue;
        this.rcfSeed = rcfSeed;
        this.maxRoundofColdStart = maxRoundofColdStart;
        this.initialAcceptFraction = numMinSamples * 1.0d / rcfSampleSize;
    }

    public EntityColdStarter(Clock clock, ThreadPool threadPool, NodeStateManager nodeStateManager, int rcfSampleSize,
            int numberOfTrees, double rcfTimeDecay, int numMinSamples, int maxSampleStride, int maxTrainSamples,
            Interpolator interpolator, SearchFeatureDao searchFeatureDao, double thresholdMinPvalue,
            FeatureManager featureManager, Settings settings, Duration modelTtl,
            CheckpointWriteWorker checkpointWriteQueue, int maxRoundofColdStart) {
        this(clock, threadPool, nodeStateManager, rcfSampleSize, numberOfTrees, rcfTimeDecay, numMinSamples,
            maxSampleStride, maxTrainSamples, interpolator, searchFeatureDao, thresholdMinPvalue, featureManager,
            settings, modelTtl, checkpointWriteQueue, -1L, maxRoundofColdStart);
    }

    /**
     * Starts an asynchronous cold start for one entity model, unless it is throttled or was attempted recently.
     *
     * <p>The listener is always completed exactly once. It completes immediately, with {@code null}, when:
     * <ul>
     *   <li>the overload cool-down is still active; or</li>
     *   <li>the detector's DoorKeeper reports the model id as already seen.</li>
     * </ul>
     * Otherwise the listener is completed by the asynchronous pipeline.
     */
    private void coldStart(String modelId, Entity entity, String detectorId, ModelState<EntityModel> modelState,
            AnomalyDetector detector, ActionListener<Void> listener) {
        logger.debug("Trigger cold start for {}", modelId);

        // Inside the cool-down window that follows an overload failure: skip.
        if (lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) {
            listener.onResponse(null);
            return;
        }

        boolean earlyExit = true;
        try {
            DoorKeeper doorKeeper = doorKeepers.computeIfAbsent(detectorId, id -> new DoorKeeper(
                DOOR_KEEPER_MAX_INSERTION,
                DOOR_KEEPER_FALSE_POSITIVE_RATE,
                detector.getDetectionIntervalDuration().multipliedBy(DOOR_KEEPER_INTERVAL_MULTIPLIER),
                clock));

            // Already attempted for this model (per the DoorKeeper): do not retry yet.
            if (doorKeeper.mightContain(modelId)) {
                return;
            }
            doorKeeper.put(modelId);

            ActionListener<Optional<List<double[][]>>> coldStartCallBack = ActionListener.wrap(
                trainingData -> onColdStartData(trainingData, modelId, entity, modelState, detector, listener),
                exception -> onColdStartFailure(exception, modelId, detectorId, listener));

            threadPool.executor(AD_THREAD_POOL_NAME).execute(() -> getEntityColdStartData(
                detectorId,
                entity,
                new ThreadedActionListener<>(logger, threadPool, AD_THREAD_POOL_NAME, coldStartCallBack, false)));
            earlyExit = false;
        } finally {
            // Any path that did not hand the listener to the async pipeline must still complete it.
            if (earlyExit) {
                listener.onResponse(null);
            }
        }
    }

    /**
     * Success handler for cold-start data.
     *
     * <p>Merges the fetched data into the model's samples, then:
     * <ul>
     *   <li>trains the model if there are now enough samples;</li>
     *   <li>otherwise only checkpoints the state.</li>
     * </ul>
     * Any exception is routed to {@code listener.onFailure}.
     */
    private void onColdStartData(Optional<List<double[][]>> trainingData, String modelId, Entity entity,
            ModelState<EntityModel> modelState, AnomalyDetector detector, ActionListener<Void> listener) {
        try {
            if (trainingData.isPresent()) {
                combineTrainSamples(trainingData.get(), modelId, modelState);
                Queue<double[]> samples = modelState.getModel().getSamples();
                if (samples.size() >= numMinSamples) {
                    // trainModelFromDataSegments also enqueues a checkpoint write of the trained state.
                    trainModelFromDataSegments(samples, entity, modelState, detector.getShingleSize());
                    logger.info("Succeeded in training entity: {}", modelId);
                } else {
                    // Not enough to train: checkpoint the merged samples anyway.
                    checkpointWriteQueue.write(modelState, true, RequestPriority.MEDIUM);
                    logger.info("Not enough data to train entity: {}, currently we have {}", modelId, samples.size());
                }
            } else {
                logger.info("Cannot get training data for {}", modelId);
            }
            listener.onResponse(null);
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Failure handler for the cold-start pipeline.
     *
     * <p>The behavior depends on the root cause:
     * <ul>
     *   <li>Overload: starts the cool-down window.</li>
     *   <li>Otherwise: records the error on the node state manager, wrapping non-AD exceptions.</li>
     * </ul>
     * In both cases the listener is then failed.
     */
    private void onColdStartFailure(Exception exception, String modelId, String detectorId,
            ActionListener<Void> listener) {
        try {
            logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception);
            Throwable cause = Throwables.getRootCause(exception);
            if (ExceptionUtil.isOverloaded(cause)) {
                logger.error("too many requests");
                // Use the injected clock so this is comparable with the cool-down check in coldStart.
                lastThrottledColdStartTime = clock.instant();
            } else if (cause instanceof AnomalyDetectionException || exception instanceof AnomalyDetectionException) {
                nodeStateManager.setException(detectorId, exception);
            } else {
                nodeStateManager.setException(detectorId, new AnomalyDetectionException(detectorId, cause));
            }
            listener.onFailure(exception);
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Trains a thresholded RCF on every queued point and stores it in the model state.
     *
     * <p>Steps:
     * <ul>
     *   <li>Builds a {@link ThresholdedRandomCutForest} with internal shingling.</li>
     *   <li>Feeds it every point in {@code dataPoints}. This drains the queue.</li>
     *   <li>Attaches the forest to the state's model.</li>
     *   <li>Refreshes the state's last-used time.</li>
     *   <li>Enqueues a checkpoint write.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code dataPoints} is null/empty or its first point is null/empty
     */
    private void trainModelFromDataSegments(Queue<double[]> dataPoints, Entity entity,
            ModelState<EntityModel> entityState, int shingleSize) {
        if (dataPoints == null || dataPoints.isEmpty()) {
            throw new IllegalArgumentException("Data points must not be empty.");
        }
        double[] firstPoint = dataPoints.peek();
        if (firstPoint == null || firstPoint.length == 0) {
            throw new IllegalArgumentException("Data points must not be empty.");
        }
        // With internal shingling, the forest dimension is (features per point) * shingleSize.
        int dimensions = firstPoint.length * shingleSize;
        ThresholdedRandomCutForest.Builder rcfBuilder = ThresholdedRandomCutForest.builder()
            .dimensions(dimensions)
            .sampleSize(rcfSampleSize)
            .numberOfTrees(numberOfTrees)
            .timeDecay(rcfTimeDecay)
            .outputAfter(numMinSamples)
            .initialAcceptFraction(initialAcceptFraction)
            .parallelExecutionEnabled(false)
            .compact(true)
            .precision(Precision.FLOAT_32)
            .boundingBoxCacheFraction(0.0)
            .shingleSize(shingleSize)
            .internalShinglingEnabled(true)
            .anomalyRate(1.0 - thresholdMinPvalue);
        if (rcfSeed > 0L) {
            rcfBuilder.randomSeed(rcfSeed);
        }
        ThresholdedRandomCutForest trcf = new ThresholdedRandomCutForest(rcfBuilder);
        while (!dataPoints.isEmpty()) {
            trcf.process(dataPoints.poll(), 0L);
        }
        EntityModel model = entityState.getModel();
        if (model == null) {
            // NOTE(review): this new model is not attached to entityState, so the trained forest is not retained.
            model = new EntityModel(entity, new ArrayDeque<double[]>(), null);
        }
        model.setTrcf(trcf);
        entityState.setLastUsedTime(clock.instant());
        checkpointWriteQueue.write(entityState, true, RequestPriority.MEDIUM);
    }

    /**
     * Loads the detector, the entity's earliest data time and the detector job, then collects cold-start data.
     *
     * <p>The data window runs from the earliest data time to the job's enabled time. The listener receives
     * {@code Optional.empty()} if no earliest data time is found. It fails with {@link EndRunException} if the
     * detector or the job is missing.
     */
    private void getEntityColdStartData(String detectorId, Entity entity,
            ActionListener<Optional<List<double[][]>>> listener) {
        ActionListener<Optional<AnomalyDetector>> getDetectorListener = ActionListener.wrap(detectorOp -> {
            if (!detectorOp.isPresent()) {
                listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", false));
                return;
            }
            List<double[][]> coldStartData = new ArrayList<>();
            AnomalyDetector detector = detectorOp.get();

            ActionListener<Optional<Long>> minTimeListener = ActionListener.wrap(earliest -> {
                if (!earliest.isPresent()) {
                    // No earliest data time found for this entity.
                    listener.onResponse(Optional.empty());
                    return;
                }
                long startTimeMs = earliest.get();
                nodeStateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(jobOp -> {
                    if (!jobOp.isPresent()) {
                        listener.onFailure(
                            new EndRunException(detectorId, "AnomalyDetector job is not available.", false));
                        return;
                    }
                    AnomalyDetectorJob job = jobOp.get();
                    long endTimeMs = job.getEnabledTime().toEpochMilli();
                    Pair<Integer, Integer> params = selectRangeParam(detector);
                    getFeatures(listener, 0, coldStartData, detector, entity, params.getLeft(), params.getRight(),
                        startTimeMs, endTimeMs);
                }, listener::onFailure));
            }, listener::onFailure);

            searchFeatureDao.getEntityMinDataTime(detector, entity,
                new ThreadedActionListener<>(logger, threadPool, AD_THREAD_POOL_NAME, minTimeListener, false));
        }, listener::onFailure);

        nodeStateManager.getAnomalyDetector(detectorId,
            new ThreadedActionListener<>(logger, threadPool, AD_THREAD_POOL_NAME, getDetectorListener, false));
    }

    /**
     * Collects interpolated feature segments for one round and recurses further back in time if needed.
     *
     * <p>Stops and reports the accumulated data when any of these holds:
     * <ul>
     *   <li>the window is shorter than one detector interval;</li>
     *   <li>no sample ranges fit the window;</li>
     *   <li>there are at least {@code shingleSize + numMinSamples} points;</li>
     *   <li>{@code maxRoundofColdStart} rounds have run.</li>
     * </ul>
     */
    private void getFeatures(ActionListener<Optional<List<double[][]>>> listener, int round,
            List<double[][]> lastRoundColdStartData, AnomalyDetector detector, Entity entity, int stride,
            int numberOfSamples, long startTimeMs, long endTimeMs) {
        if (startTimeMs >= endTimeMs || endTimeMs - startTimeMs < detector.getDetectorIntervalInMilliseconds()) {
            listener.onResponse(Optional.of(lastRoundColdStartData));
            return;
        }
        List<Map.Entry<Long, Long>> sampleRanges =
            getTrainSampleRanges(detector, startTimeMs, endTimeMs, stride, numberOfSamples);
        if (sampleRanges.isEmpty()) {
            listener.onResponse(Optional.of(lastRoundColdStartData));
            return;
        }

        ActionListener<List<Optional<double[]>>> getFeatureListener = ActionListener.wrap(featureSamples -> {
            List<double[][]> currentRoundColdStartData = interpolateSamples(featureSamples, stride);
            // Previous rounds covered later time ranges, so their data goes after this round's data.
            currentRoundColdStartData.addAll(lastRoundColdStartData);
            if (calculateColdStartDataSize(currentRoundColdStartData) >= detector.getShingleSize() + numMinSamples
                || round + 1 >= maxRoundofColdStart) {
                listener.onResponse(Optional.of(currentRoundColdStartData));
            } else {
                // getTrainSampleRanges lists ranges newest first, so the last entry is this round's oldest range.
                long lastSampleStartTime = sampleRanges.get(sampleRanges.size() - 1).getKey();
                getFeatures(listener, round + 1, currentRoundColdStartData, detector, entity, stride,
                    numberOfSamples, startTimeMs, lastSampleStartTime);
            }
        }, listener::onFailure);

        try {
            searchFeatureDao.getColdStartSamplesForPeriods(detector, sampleRanges, entity, true,
                new ThreadedActionListener<>(logger, threadPool, AD_THREAD_POOL_NAME, getFeatureListener, false));
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Converts one round of sampled feature vectors into consecutive segments.
     *
     * <p>Behavior:
     * <ul>
     *   <li>Missing samples ({@code Optional.empty()}) are skipped.</li>
     *   <li>Between two present samples {@code k} positions apart, the interpolator is asked for
     *       {@code k * stride + 1} points, and all but the last are kept.</li>
     *   <li>The final present sample is appended as a single-point segment.</li>
     * </ul>
     *
     * @return a new mutable list (empty if no sample is present)
     */
    private List<double[][]> interpolateSamples(List<Optional<double[]>> featureSamples, int stride) {
        List<double[][]> segments = new ArrayList<>();
        Pair<Integer, double[]> lastSample = null;
        for (int i = 0; i < featureSamples.size(); ++i) {
            Optional<double[]> featuresOptional = featureSamples.get(i);
            if (!featuresOptional.isPresent()) {
                continue;
            }
            double[] features = featuresOptional.get();
            if (lastSample != null) {
                int numInterpolants = (i - lastSample.getLeft()) * stride + 1;
                double[][] points = featureManager.transpose(interpolator.interpolate(
                    featureManager.transpose(new double[][] { lastSample.getRight(), features }), numInterpolants));
                // Drop the last point; presumably it is the current sample, which a later segment adds — confirm.
                segments.add(Arrays.copyOfRange(points, 0, points.length - 1));
            }
            lastSample = Pair.of(i, features);
        }
        if (lastSample != null) {
            segments.add(new double[][] { lastSample.getRight() });
        }
        return segments;
    }

    /** Returns the total number of points across all segments. */
    private int calculateColdStartDataSize(List<double[][]> coldStartData) {
        return coldStartData.stream().mapToInt(segment -> segment.length).sum();
    }

    /**
     * Chooses the sampling stride (in detector intervals) and the number of samples to request.
     *
     * <ul>
     *   <li>Intervals of at most 30 minutes that divide 60 evenly: stride is {@code 60 / interval}, i.e. one
     *       sample per hour. The sample count is {@code shingleSize + numMinSamples} rounded up to a multiple
     *       of 24.</li>
     *   <li>Any other interval: stride 1 and exactly {@code shingleSize + numMinSamples} samples.</li>
     * </ul>
     */
    private Pair<Integer, Integer> selectRangeParam(AnomalyDetector detector) {
        long delta = detector.getDetectorIntervalInMinutes();
        int shingleSize = detector.getShingleSize();
        int strideLength;
        int numberOfSamples;
        // delta > 0 keeps the modulo below from dividing by zero.
        if (delta > 0L && delta <= 30L && 60L % delta == 0L) {
            strideLength = (int) (60L / delta);
            numberOfSamples = (int) Math.ceil((shingleSize + numMinSamples) / 24.0) * 24;
        } else {
            strideLength = 1;
            numberOfSamples = shingleSize + numMinSamples;
        }
        return Pair.of(strideLength, numberOfSamples);
    }

    /**
     * Builds up to {@code numberOfSamples} (start, end) ranges, each one detector interval wide.
     *
     * <p>The first range ends at {@code endMilli}. Each following range ends {@code stride} intervals earlier,
     * so the list is ordered newest first. The count is also capped by how many strides fit between
     * {@code startMilli} and {@code endMilli}.
     */
    private List<Map.Entry<Long, Long>> getTrainSampleRanges(AnomalyDetector detector, long startMilli,
            long endMilli, int stride, int numberOfSamples) {
        long bucketSize = ((IntervalTimeConfiguration) detector.getDetectionInterval()).toDuration().toMillis();
        int numBuckets = (int) Math.floor((double) (endMilli - startMilli) / (double) bucketSize);
        int numStrides = Math.min((int) Math.floor((double) numBuckets / (double) stride), numberOfSamples);
        return Stream.iterate(endMilli, end -> end - (long) stride * bucketSize)
            .limit(numStrides)
            .<Map.Entry<Long, Long>>map(end -> new AbstractMap.SimpleImmutableEntry<>(end - bucketSize, end))
            .collect(Collectors.toList());
    }

    /**
     * Trains the entity model for {@code modelState}.
     *
     * <ul>
     *   <li>Fewer than {@code numMinSamples} held samples: triggers a cold start.</li>
     *   <li>Otherwise: trains directly from the held samples.</li>
     *   <li>Detector not found: fails the listener with {@link AnomalyDetectionException}.</li>
     * </ul>
     */
    public void trainModel(Entity entity, String detectorId, ModelState<EntityModel> modelState,
            ActionListener<Void> listener) {
        nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> {
            if (!detectorOptional.isPresent()) {
                logger.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId));
                listener.onFailure(new AnomalyDetectionException(detectorId, "fail to find detector"));
                return;
            }
            AnomalyDetector detector = detectorOptional.get();
            Queue<double[]> samples = modelState.getModel().getSamples();
            String modelId = modelState.getModelId();
            if (samples.size() < numMinSamples) {
                coldStart(modelId, entity, detectorId, modelState, detector, listener);
            } else {
                try {
                    trainModelFromDataSegments(samples, entity, modelState, detector.getShingleSize());
                    listener.onResponse(null);
                } catch (Exception e) {
                    listener.onFailure(e);
                }
            }
        }, listener::onFailure));
    }

    /**
     * Best-effort synchronous training from the samples already held by the model.
     *
     * <p>Does nothing if the state, model or samples are null, or if there are fewer than {@code numMinSamples}
     * samples. Training errors are logged, not thrown.
     */
    public void trainModelFromExistingSamples(ModelState<EntityModel> modelState, int shingleSize) {
        if (modelState == null || modelState.getModel() == null || modelState.getModel().getSamples() == null) {
            return;
        }
        EntityModel model = modelState.getModel();
        Queue<double[]> samples = model.getSamples();
        if (samples.size() >= numMinSamples) {
            try {
                trainModelFromDataSegments(samples, model.getEntity().orElse(null), modelState, shingleSize);
            } catch (Exception e) {
                logger.error("Unexpected training error", e);
            }
        }
    }

    /**
     * Replaces the model's samples with the flattened cold-start points, followed by the samples it already held.
     * Does nothing when {@code coldstartDatapoints} is null or empty.
     */
    private void combineTrainSamples(List<double[][]> coldstartDatapoints, String modelId,
            ModelState<EntityModel> entityState) {
        if (coldstartDatapoints == null || coldstartDatapoints.isEmpty()) {
            return;
        }
        EntityModel model = entityState.getModel();
        if (model == null) {
            // NOTE(review): this new model is not attached to entityState, so the merged samples are not retained.
            model = new EntityModel(null, new ArrayDeque<double[]>(), null);
        }
        ArrayDeque<double[]> newSamples = new ArrayDeque<double[]>();
        for (double[][] consecutivePoints : coldstartDatapoints) {
            newSamples.addAll(Arrays.asList(consecutivePoints));
        }
        newSamples.addAll(model.getSamples());
        model.setSamples(newSamples);
    }

    /** Drops expired DoorKeepers and runs maintenance on the rest. */
    @Override
    public void maintenance() {
        doorKeepers.forEach((detectorId, doorKeeper) -> {
            if (doorKeeper.expired(modelTtl)) {
                // Conditional remove: don't evict a DoorKeeper that was concurrently replaced for this detector.
                doorKeepers.remove(detectorId, doorKeeper);
            } else {
                doorKeeper.maintenance();
            }
        });
    }

    /** Forgets the DoorKeeper for {@code detectorId}, if any. */
    @Override
    public void clear(String detectorId) {
        doorKeepers.remove(detectorId);
    }
}

