/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ad;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import java.util.EnumMap;
import java.util.Locale;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ad.breaker.ADCircuitBreakerService;
import org.opensearch.ad.common.exception.LimitExceededException;
import org.opensearch.ad.settings.AnomalyDetectorSettings;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.monitor.jvm.JvmService;

/**
 * Tracks heap memory consumed (and reserved) by anomaly-detection models, both in aggregate and
 * broken down by {@link Origin}, and decides whether new models may be hosted.
 *
 * <p>All mutable accounting state is guarded by this instance's monitor. The heap limit is
 * additionally {@code volatile} because it is updated from a cluster-settings callback and read
 * by an unsynchronized getter.
 */
public class MemoryTracker {
    private static final Logger LOG = LogManager.getLogger(MemoryTracker.class);

    /** Fixed byte size reported by {@link #getThresholdModelBytes()}. */
    private static final int THRESHOLD_MODEL_BYTES = 180000;

    private long totalMemoryBytes = 0L;
    private final Map<Origin, Long> totalMemoryBytesByOrigin = new EnumMap<>(Origin.class);
    private long reservedMemoryBytes = 0L;
    private final Map<Origin, Long> reservedMemoryBytesByOrigin = new EnumMap<>(Origin.class);
    private final long heapSize;
    // volatile: written by the MODEL_MAX_SIZE_PERCENTAGE settings consumer without holding the
    // monitor, and read by the unsynchronized getHeapLimit().
    private volatile long heapLimitBytes;
    private final long desiredModelSize;
    private final int thresholdModelBytes;
    private final ADCircuitBreakerService adCircuitBreakerService;

    /**
     * Creates a tracker sized against the JVM's maximum heap.
     *
     * @param jvmService source of the maximum heap size
     * @param modelMaxSizePercentage fraction of max heap usable by models (the hard limit)
     * @param modelDesiredSizePercentage fraction of max heap reported by {@link #getDesiredModelSize()}
     * @param clusterService used to subscribe to updates of
     *     {@code AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE}
     * @param adCircuitBreakerService when open, all allocation checks fail
     */
    public MemoryTracker(
        JvmService jvmService,
        double modelMaxSizePercentage,
        double modelDesiredSizePercentage,
        ClusterService clusterService,
        ADCircuitBreakerService adCircuitBreakerService
    ) {
        this.heapSize = jvmService.info().getMem().getHeapMax().getBytes();
        this.heapLimitBytes = (long) (heapSize * modelMaxSizePercentage);
        this.desiredModelSize = (long) (heapSize * modelDesiredSizePercentage);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(
                AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE,
                it -> this.heapLimitBytes = (long) (heapSize * it)
            );
        this.thresholdModelBytes = THRESHOLD_MODEL_BYTES;
        this.adCircuitBreakerService = adCircuitBreakerService;
    }

    /**
     * Checks whether a model of the given size fits within the reserved-memory budget.
     *
     * @param detectorId detector id, used in the exception
     * @param trcf the model whose size is estimated
     * @return {@code true} when hosting is allowed
     * @throws LimitExceededException when the model does not fit. Note this is also thrown when
     *     the circuit breaker is open, even if the size alone would fit.
     */
    public synchronized boolean isHostingAllowed(String detectorId, ThresholdedRandomCutForest trcf) {
        long requiredBytes = estimateTRCFModelSize(trcf);
        if (canAllocateReserved(requiredBytes)) {
            return true;
        }
        throw new LimitExceededException(
            detectorId,
            String.format(
                Locale.ROOT,
                "Exceeded memory limit. New size is %d bytes and max limit is %d bytes",
                reservedMemoryBytes + requiredBytes,
                heapLimitBytes
            )
        );
    }

    /**
     * @return {@code true} if the circuit breaker is closed and {@code requiredBytes} more reserved
     *     memory stays within the heap limit
     */
    public synchronized boolean canAllocateReserved(long requiredBytes) {
        return !adCircuitBreakerService.isOpen() && reservedMemoryBytes + requiredBytes <= heapLimitBytes;
    }

    /**
     * @return {@code true} if the circuit breaker is closed and {@code bytes} more total memory
     *     stays within the heap limit
     */
    public synchronized boolean canAllocate(long bytes) {
        return !adCircuitBreakerService.isOpen() && totalMemoryBytes + bytes <= heapLimitBytes;
    }

    /**
     * Records memory consumption. It is always added to the total, and also to the reserved
     * counters when {@code reserved} is true.
     */
    public synchronized void consumeMemory(long memoryToConsume, boolean reserved, Origin origin) {
        totalMemoryBytes += memoryToConsume;
        adjustOriginMemoryConsumption(memoryToConsume, origin, totalMemoryBytesByOrigin);
        if (reserved) {
            reservedMemoryBytes += memoryToConsume;
            adjustOriginMemoryConsumption(memoryToConsume, origin, reservedMemoryBytesByOrigin);
        }
    }

    /** Adds {@code memoryToConsume} to the per-origin entry, starting from 0 when absent. */
    private void adjustOriginMemoryConsumption(long memoryToConsume, Origin origin, Map<Origin, Long> mapToUpdate) {
        mapToUpdate.merge(origin, memoryToConsume, Long::sum);
    }

    /**
     * Records memory release. It is always subtracted from the total, and also from the reserved
     * counters when {@code reserved} is true.
     */
    public synchronized void releaseMemory(long memoryToShed, boolean reserved, Origin origin) {
        totalMemoryBytes -= memoryToShed;
        adjustOriginMemoryRelease(memoryToShed, origin, totalMemoryBytesByOrigin);
        if (reserved) {
            reservedMemoryBytes -= memoryToShed;
            adjustOriginMemoryRelease(memoryToShed, origin, reservedMemoryBytesByOrigin);
        }
    }

    /**
     * Subtracts {@code memoryToShed} from the per-origin entry. It does nothing when the origin
     * has no entry, even though the aggregate counter was still decremented by the caller.
     */
    private void adjustOriginMemoryRelease(long memoryToShed, Origin origin, Map<Origin, Long> mapToUpdate) {
        mapToUpdate.computeIfPresent(origin, (key, current) -> current - memoryToShed);
    }

    /** Estimates the heap size of {@code trcf} from its forest's configuration. */
    public long estimateTRCFModelSize(ThresholdedRandomCutForest trcf) {
        RandomCutForest forest = trcf.getForest();
        return estimateTRCFModelSize(
            forest.getDimensions(),
            forest.getNumberOfTrees(),
            forest.getBoundingBoxCacheFraction(),
            forest.getShingleSize(),
            forest.isInternalShinglingEnabled()
        );
    }

    /**
     * Estimates the heap size, in bytes, of a thresholded RCF model with the given configuration.
     *
     * <p>The numeric constants are taken as-is from the existing estimation formula. Their
     * derivation is not documented here.
     *
     * @throws IllegalArgumentException if internal shingling is enabled and {@code shingleSize > 64}
     */
    public long estimateTRCFModelSize(
        int dimension,
        int numberOfTrees,
        double boundingBoxCacheFraction,
        int shingleSize,
        boolean internalShingling
    ) {
        double averagePointStoreUsage;
        if (!internalShingling || shingleSize == 1) {
            averagePointStoreUsage = 1.0;
        } else if (shingleSize <= 3) {
            averagePointStoreUsage = 0.53;
        } else if (shingleSize <= 12) {
            averagePointStoreUsage = 0.27;
        } else if (shingleSize <= 24) {
            averagePointStoreUsage = 0.13;
        } else if (shingleSize <= 64) {
            averagePointStoreUsage = 0.07;
        } else {
            throw new IllegalArgumentException("out of range shingle size " + shingleSize);
        }
        double actualBoundingBoxUsage = boundingBoxCacheFraction >= 0.3 ? 1.0 : boundingBoxCacheFraction;
        // Evaluate in double/long rather than int: the original int products
        // (e.g. 256 * trees * dimension * 4) overflow silently for large models. For inputs that
        // did not overflow, these integer-valued doubles are exact, so results are unchanged.
        double boundingBoxBytesPerTree = (1040.0 + 255.0 * (dimension * 8.0 + 64.0)) * actualBoundingBoxUsage;
        double pointStoreBytes = 256.0 * numberOfTrees * dimension * 4.0 * averagePointStoreUsage;
        long compactRcfSize = (long) (56.0 + numberOfTrees * (8624.0 + boundingBoxBytesPerTree) + pointStoreBytes + 77192.0);
        long thresholdSize = 6L * (dimension * 8L + 16L) + shingleSize * 8L + 624L;
        return compactRcfSize + thresholdSize;
    }

    /** @return bytes by which total consumption exceeds the heap limit (negative if under it) */
    public synchronized long memoryToShed() {
        return totalMemoryBytes - heapLimitBytes;
    }

    /** @return the current heap limit in bytes */
    public long getHeapLimit() {
        return heapLimitBytes;
    }

    /** @return the desired model size in bytes computed at construction */
    public long getDesiredModelSize() {
        return desiredModelSize;
    }

    /**
     * @return the total consumed bytes. Synchronized so the non-volatile long is read consistently
     *     with writers.
     */
    public synchronized long getTotalMemoryBytes() {
        return totalMemoryBytes;
    }

    /**
     * Overwrites the recorded per-origin totals with externally-measured values and adjusts the
     * aggregate counters by the difference.
     *
     * @return {@code false} if recorded and actual values already matched, {@code true} if
     *     anything was changed
     */
    public synchronized boolean syncMemoryState(Origin origin, long totalBytes, long reservedBytes) {
        long recordedTotalBytes = totalMemoryBytesByOrigin.getOrDefault(origin, 0L);
        long recordedReservedBytes = reservedMemoryBytesByOrigin.getOrDefault(origin, 0L);
        if (totalBytes == recordedTotalBytes && reservedBytes == recordedReservedBytes) {
            return false;
        }
        LOG.info(
            "Memory states do not match. Recorded: total bytes {}, reserved bytes {}. Actual: total bytes {}, reserved bytes: {}",
            recordedTotalBytes,
            recordedReservedBytes,
            totalBytes,
            reservedBytes
        );
        reservedMemoryBytes += reservedBytes - recordedReservedBytes;
        reservedMemoryBytesByOrigin.put(origin, reservedBytes);
        totalMemoryBytes += totalBytes - recordedTotalBytes;
        totalMemoryBytesByOrigin.put(origin, totalBytes);
        return true;
    }

    /** @return the fixed threshold-model byte size ({@value #THRESHOLD_MODEL_BYTES}) */
    public int getThresholdModelBytes() {
        return thresholdModelBytes;
    }

    /** Category under which memory consumption is attributed. */
    public enum Origin {
        SINGLE_ENTITY_DETECTOR,
        HC_DETECTOR,
        HISTORICAL_SINGLE_ENTITY_DETECTOR
    }
}

