/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.randomcutforest.imputation;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.returntypes.ConditionalTreeSample;
import com.amazon.randomcutforest.returntypes.SampleSummary;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Condenses the conditional samples produced for a query point with some missing dimensions into a
 * {@link SampleSummary}.
 *
 * <p>The summary holds a per-coordinate weighted median completion, the weighted mean and standard
 * deviation of every coordinate, and a short list of "typical" completions with their likelihoods.
 * Typical completions are found by clustering the samples on the missing dimensions only, using L1
 * distance and per-dimension weighted medians as cluster centers.
 */
public class ConditionalSampleSummarizer {
    /**
     * Two centers are merged (the lighter one is dropped) when the sum of their radii divided by the
     * L1 distance between them reaches this value.
     */
    public static double SEPARATION_RATIO_FOR_MERGE = 0.8;

    /**
     * A point shares its weight among every center whose distance is at most this multiple of the
     * distance to its nearest center.
     */
    public static double WEIGHT_ALLOCATION_THRESHOLD = 1.25;

    /** Upper bound on the number of typical points, per missing dimension. */
    public static int MAX_NUMBER_OF_TYPICAL_PER_DIMENSION = 2;

    /** Absolute upper bound on the number of typical points. */
    public static int MAX_NUMBER_OF_TYPICAL_ELEMENTS = 5;

    /** Indices of the coordinates of {@link #queryPoint} that are being imputed. */
    protected int[] missingDimensions;

    /** The query point; every reported point starts as a copy of it. */
    protected float[] queryPoint;

    /**
     * Blends a quantile-based distance cutoff with the largest sample distance when deciding which
     * samples take part in clustering (see {@code computeThreshold}).
     */
    protected double centrality;

    /**
     * @param missingDimensions indices of the coordinates to impute (defensively copied)
     * @param queryPoint        the point being completed (defensively copied)
     * @param centrality        blend factor used for the clustering distance cutoff
     */
    public ConditionalSampleSummarizer(int[] missingDimensions, float[] queryPoint, double centrality) {
        this.missingDimensions = Arrays.copyOf(missingDimensions, missingDimensions.length);
        this.queryPoint = Arrays.copyOf(queryPoint, queryPoint.length);
        this.centrality = centrality;
    }

    /**
     * Summarizes the given conditional samples.
     *
     * <p>The size of {@code alist} is used as the total weight for normalizing means, deviations and
     * likelihoods. NOTE(review): this assumes {@code ConditionalTreeSample.dedup} preserves the total
     * weight of the samples -- confirm.
     *
     * @param alist samples for the query point; must be non-empty (checked via
     *              {@code CommonUtils.checkArgument})
     * @return the total weight, the typical points (heaviest first) with their likelihoods, the median
     *         completion, and the per-coordinate mean and standard deviation
     */
    public SampleSummary summarize(List<ConditionalTreeSample> alist) {
        CommonUtils.checkArgument(alist.size() > 0, "incorrect call to summarize");
        double totalWeight = alist.size();
        List<ConditionalTreeSample> newList = ConditionalTreeSample.dedup(alist);
        newList.sort((o1, o2) -> Double.compare(o1.distance, o2.distance));

        // newList is sorted by distance, so the samples within the threshold form a prefix [0, num).
        double threshold = computeThreshold(newList, totalWeight);
        int num = 0;
        while (num < newList.size() && newList.get(num).distance <= threshold) {
            ++num;
        }

        int dimensions = this.queryPoint.length;
        // Accumulate in double: summing many weighted float products in float loses precision.
        double[] coordSum = new double[dimensions];
        double[] coordSqSum = new double[dimensions];
        Center center = new Center(this.missingDimensions.length);
        ProjectedPoint[] points = new ProjectedPoint[num];
        for (int j = 0; j < newList.size(); ++j) {
            ConditionalTreeSample e = newList.get(j);
            float[] values = new float[this.missingDimensions.length];
            for (int y = 0; y < values.length; ++y) {
                values[y] = e.leafPoint[this.missingDimensions[y]];
            }
            // The median uses every sample; clustering only uses those within the threshold.
            center.add(values, e.weight);
            for (int i = 0; i < dimensions; ++i) {
                double value = e.leafPoint[i];
                coordSum[i] += value * e.weight;
                coordSqSum[i] += value * value * e.weight;
            }
            if (j < num) {
                // Every sample in the prefix satisfies distance <= threshold, so its weight is kept as-is.
                points[j] = new ProjectedPoint(values, e.weight);
            }
        }
        center.recompute();
        float[] median = complete(center.coordinate);

        float[] coordMean = new float[dimensions];
        float[] deviation = new float[dimensions];
        for (int i = 0; i < dimensions; ++i) {
            double mean = coordSum[i] / totalWeight;
            coordMean[i] = (float) mean;
            // max(0, .) absorbs tiny negative variances caused by rounding
            deviation[i] = (float) Math.sqrt(Math.max(0.0, coordSqSum[i] / totalWeight - mean * mean));
        }

        int maxAllowed = Math.min(this.missingDimensions.length * MAX_NUMBER_OF_TYPICAL_PER_DIMENSION,
                MAX_NUMBER_OF_TYPICAL_ELEMENTS);
        List<Center> centers = seedCenters(points, center.coordinate, maxAllowed);
        refineCenters(points, centers, maxAllowed);
        centers.sort((o1, o2) -> Double.compare(o2.weight, o1.weight));

        float[][] pointList = new float[centers.size()][];
        float[] likelihood = new float[centers.size()];
        for (int i = 0; i < centers.size(); ++i) {
            Center c = centers.get(i);
            pointList[i] = complete(c.coordinate);
            likelihood[i] = (float) (c.weight / totalWeight);
        }
        return new SampleSummary(totalWeight, pointList, likelihood, median, coordMean, deviation);
    }

    /**
     * Computes the distance cutoff for samples that take part in clustering.
     *
     * <p>The cutoff is {@code centrality} times the distance of each sample at which the cumulative
     * weight crosses 1/3 and 1/2 of the weight of the non-zero-distance samples, plus
     * {@code (1 - centrality)} times the largest distance.
     *
     * <p>NOTE(review): the cumulative weight starts at index 1, so the weight of the first sample is
     * never counted. Also, both crossings are added rather than averaged. Both are kept as-is.
     *
     * @param sorted      deduplicated samples sorted by ascending distance; must be non-empty
     * @param totalWeight total weight of the samples
     * @return the distance threshold
     */
    private double computeThreshold(List<ConditionalTreeSample> sorted, double totalWeight) {
        double remainderWeight = totalWeight;
        for (int k = 0; k < sorted.size() && sorted.get(k).distance == 0.0; ++k) {
            remainderWeight -= sorted.get(k).weight;
        }
        double oneThird = remainderWeight / 3.0;
        double half = remainderWeight / 2.0;
        double threshold = 0.0;
        double currentWeight = 0.0;
        for (int j = 1; j < sorted.size(); ++j) {
            double w = sorted.get(j).weight;
            boolean crossesThird = currentWeight < oneThird && currentWeight + w >= oneThird;
            boolean crossesHalf = currentWeight < half && currentWeight + w >= half;
            if (crossesThird || crossesHalf) {
                threshold += this.centrality * sorted.get(j).distance;
            }
            currentWeight += w;
        }
        return threshold + (1.0 - this.centrality) * sorted.get(sorted.size() - 1).distance;
    }

    /**
     * Seeds cluster centers with a farthest-point heuristic.
     *
     * <p>Starts from {@code initial} and repeatedly adds the point farthest (L1) from every existing
     * center. Stops after {@code 2 * maxAllowed} additions, or once every point coincides with a
     * center.
     *
     * @param points     projected points to seed from
     * @param initial    coordinate of the first center (copied)
     * @param maxAllowed maximum number of centers eventually reported
     * @return a mutable list of seed centers
     */
    private List<Center> seedCenters(ProjectedPoint[] points, float[] initial, int maxAllowed) {
        List<Center> centers = new ArrayList<>();
        centers.add(new Center(initial));
        for (int k = 0; k < 2 * maxAllowed; ++k) {
            double maxDist = 0.0;
            int maxIndex = -1;
            for (int j = 0; j < points.length; ++j) {
                double minDist = Double.MAX_VALUE;
                for (Center c : centers) {
                    minDist = Math.min(minDist, distance(points[j], c));
                }
                if (minDist > maxDist) {
                    maxDist = minDist;
                    maxIndex = j;
                }
            }
            if (maxDist == 0.0) {
                break;
            }
            centers.add(new Center(points[maxIndex].coordinate));
        }
        return centers;
    }

    /**
     * Iteratively reassigns points and recomputes centers, dropping one center per pass.
     *
     * <p>On each pass, if the most overlapping pair reaches {@link #SEPARATION_RATIO_FOR_MERGE}, the
     * lighter center of that pair is removed. Otherwise, if more than {@code maxAllowed} centers remain,
     * the lightest one is removed. The loop ends when neither condition holds.
     *
     * @param points     projected points being clustered
     * @param centers    mutable list of centers, updated in place
     * @param maxAllowed maximum number of centers to keep
     */
    private void refineCenters(ProjectedPoint[] points, List<Center> centers, int maxAllowed) {
        double measure;
        do {
            allocate(points, centers);
            int first = -1;
            int second = -1;
            measure = 0.0;
            for (int a = 0; a < centers.size(); ++a) {
                for (int b = a + 1; b < centers.size(); ++b) {
                    double tempMeasure = separationMeasure(centers.get(a), centers.get(b));
                    if (measure < tempMeasure) {
                        first = a;
                        second = b;
                        measure = tempMeasure;
                    }
                }
            }
            if (measure >= SEPARATION_RATIO_FOR_MERGE) {
                centers.remove(centers.get(first).weight < centers.get(second).weight ? first : second);
            } else if (centers.size() > maxAllowed) {
                centers.sort((o1, o2) -> Double.compare(o1.weight, o2.weight));
                centers.remove(0);
            }
        } while (centers.size() > maxAllowed || measure >= SEPARATION_RATIO_FOR_MERGE);
    }

    /**
     * Resets every center, shares each point's weight among its near centers, then recomputes them.
     *
     * <p>A point at distance 0 from some center splits its weight evenly among all such centers.
     * Otherwise it gives weight proportional to {@code minDist / dist} to every center within
     * {@link #WEIGHT_ALLOCATION_THRESHOLD} times its nearest distance. Either way, the total weight
     * handed out equals the point's weight.
     */
    private void allocate(ProjectedPoint[] points, List<Center> centers) {
        centers.forEach(Center::reset);
        double[] dist = new double[centers.size()];
        for (ProjectedPoint point : points) {
            double minDist = Double.MAX_VALUE;
            for (int i = 0; i < centers.size(); ++i) {
                dist[i] = distance(point, centers.get(i));
                minDist = Math.min(minDist, dist[i]);
            }
            if (minDist == 0.0) {
                int coincident = 0;
                for (int i = 0; i < centers.size(); ++i) {
                    if (dist[i] == 0.0) {
                        ++coincident;
                    }
                }
                // split (rather than duplicate) the weight so likelihoods cannot sum above 1
                for (int i = 0; i < centers.size(); ++i) {
                    if (dist[i] == 0.0) {
                        centers.get(i).add(point.coordinate, point.weight / coincident);
                    }
                }
                continue;
            }
            double sum = 0.0;
            for (int i = 0; i < centers.size(); ++i) {
                if (dist[i] <= WEIGHT_ALLOCATION_THRESHOLD * minDist) {
                    sum += minDist / dist[i];
                }
            }
            for (int i = 0; i < centers.size(); ++i) {
                if (dist[i] <= WEIGHT_ALLOCATION_THRESHOLD * minDist) {
                    centers.get(i).add(point.coordinate, point.weight * minDist / (dist[i] * sum));
                }
            }
        }
        centers.forEach(Center::recompute);
    }

    /**
     * Overlap of two centers: the sum of their radii divided by their L1 distance.
     *
     * <p>Coincident centers are treated as maximally overlapping. Without this, two coincident
     * zero-radius centers give {@code 0.0 / 0.0 = NaN}, which never compares as large enough to merge.
     */
    private static double separationMeasure(Center a, Center b) {
        double dist = distance(a, b);
        if (dist == 0.0) {
            return Double.POSITIVE_INFINITY;
        }
        return (a.radius() + b.radius()) / dist;
    }

    /** Returns a copy of the query point with the missing dimensions filled from {@code values}. */
    private float[] complete(float[] values) {
        float[] result = Arrays.copyOf(this.queryPoint, this.queryPoint.length);
        for (int y = 0; y < this.missingDimensions.length; ++y) {
            result[this.missingDimensions[y]] = values[y];
        }
        return result;
    }

    /**
     * L1 (Manhattan) distance between two projected points, taken over the coordinates of {@code a}.
     */
    static double distance(ProjectedPoint a, ProjectedPoint b) {
        double total = 0.0;
        for (int i = 0; i < a.coordinate.length; ++i) {
            total += Math.abs(a.coordinate[i] - b.coordinate[i]);
        }
        return total;
    }

    /**
     * Resets the centers and distributes one unit of weight per point among its near centers.
     *
     * <p>Each point gives unit weight (its own weight is ignored) to every center within
     * {@link #WEIGHT_ALLOCATION_THRESHOLD} times its nearest distance. The share is proportional to
     * {@code minDistance / distance}, and centers at distance 0 count with factor 1. Centers are not
     * recomputed here.
     *
     * @param points  points to assign
     * @param centers centers to reset and fill
     */
    public void assign(ProjectedPoint[] points, List<Center> centers) {
        // reset() clears the accumulated points as well as the weight. Zeroing only the weight left
        // stale points behind and corrupted a later recompute().
        centers.forEach(Center::reset);
        for (ProjectedPoint point : points) {
            double[] dist = new double[centers.size()];
            double minDistance = Double.MAX_VALUE;
            for (int j = 0; j < centers.size(); ++j) {
                dist[j] = distance(centers.get(j), point);
                minDistance = Math.min(minDistance, dist[j]);
            }
            double cutoff = WEIGHT_ALLOCATION_THRESHOLD * minDistance;
            double sum = 0.0;
            for (int j = 0; j < centers.size(); ++j) {
                if (dist[j] <= cutoff) {
                    sum += dist[j] > 0.0 ? minDistance / dist[j] : 1.0;
                }
            }
            for (int j = 0; j < centers.size(); ++j) {
                if (dist[j] <= cutoff) {
                    double share = dist[j] == 0.0 ? 1.0 / sum : minDistance / (sum * dist[j]);
                    centers.get(j).add(point.coordinate, share);
                }
            }
        }
    }

    /**
     * A cluster center.
     *
     * <p>{@link #recompute()} places it at the per-dimension weighted median of its assigned points and
     * accumulates the weighted L1 deviation from it in {@code sumOfRadius}.
     */
    class Center extends ProjectedPoint {
        ArrayList<ProjectedPoint> points;
        double sumOfRadius;

        Center(int dimensions) {
            super(new float[dimensions], 0.0);
            this.points = new ArrayList<>();
        }

        Center(float[] coordinate) {
            super(Arrays.copyOf(coordinate, coordinate.length), 0.0);
            this.points = new ArrayList<>();
        }

        /** Records a point (coordinate array not copied) and adds its weight to this center. */
        public void add(float[] coordinate, double weight) {
            this.points.add(new ProjectedPoint(coordinate, weight));
            this.weight += weight;
        }

        /** Drops all assigned points and zeroes the weight. */
        public void reset() {
            this.points = new ArrayList<>();
            this.weight = 0.0;
        }

        /** Weighted mean L1 distance of the assigned points from this center; 0 when the weight is 0. */
        public double radius() {
            return this.weight > 0.0 ? this.sumOfRadius / this.weight : 0.0;
        }

        /**
         * Moves this center to the per-dimension weighted median of its points and recomputes
         * {@code sumOfRadius}. A zero-weight center must have no points and is moved to the origin.
         */
        public void recompute() {
            this.sumOfRadius = 0.0;
            if (this.weight == 0.0) {
                CommonUtils.checkArgument(this.points.size() == 0, "adding 0 weight points?");
                Arrays.fill(this.coordinate, 0.0f);
                return;
            }
            for (int i = 0; i < this.coordinate.length; ++i) {
                final int index = i;
                this.points.sort((o1, o2) -> Double.compare(o1.coordinate[index], o2.coordinate[index]));
                // walk the sorted points until half of the total weight has been consumed
                double runningWeight = this.weight / 2.0;
                int position = 0;
                while (runningWeight >= 0.0 && position < this.points.size()
                        && runningWeight >= this.points.get(position).weight) {
                    runningWeight -= this.points.get(position).weight;
                    ++position;
                }
                // defensive clamp against rounding pushing the walk past the last point
                position = Math.min(position, this.points.size() - 1);
                this.coordinate[index] = this.points.get(position).coordinate[index];
                for (ProjectedPoint p : this.points) {
                    this.sumOfRadius += p.weight * (double) Math.abs(this.coordinate[index] - p.coordinate[index]);
                }
            }
        }
    }

    /** A point restricted to the missing dimensions, together with its weight. */
    class ProjectedPoint {
        final float[] coordinate;
        double weight;

        ProjectedPoint(float[] coordinate, double weight) {
            this.coordinate = coordinate;
            this.weight = weight;
        }
    }
}

